# Dependencies

In [None]:
from google.colab import drive
drive.mount("/content/drive")

!pip install torch
!pip install datasets
!pip install zss
!pip install lxml
!pip install huggingface_hub
!pip install anthropic
!pip install transformers==4.47.0
!pip install python-Levenshtein
!pip install pytesseract
! apt install tesseract-ocr
! apt install libtesseract-dev

# Accuracy

In [None]:
import re
import Levenshtein
import numpy as np
from collections import defaultdict

def parse_tagged_string(input_string):
    """
    Parse a Donut-style tagged string into a structured dictionary.

    Recognised element tags are <page_header>, <paragraph>, <subheading>,
    <title>, <image>, <table>, <code_snippet> and <page_footer>. The sequence
    markers <s>, </s> and <sep/> are stripped first.

    Malformed model output is tolerated:
    - An element that is never closed ends where the next opening tag starts,
      or at the end of the string.
    - A closing tag that does not match the open element is ignored, so its
      text stays inside that element's content.
    - Content under any unrecognised tag is dropped.

    Args:
        input_string (str): The tagged input string.

    Returns:
        dict: {"document": [{"type": tag, "content": text}, ...]} in order of
        appearance, or None if the input cannot be processed as a string.
    """
    # Single source of truth for the tags whose content is kept.
    content_tags = frozenset({
        "page_header",
        "paragraph",
        "subheading",
        "title",
        "image",
        "table",
        "code_snippet",
        "page_footer",
    })

    try:
        cleaned = (
            input_string.replace("<s>", "").replace("</s>", "").replace("<sep/>", "")
        )
        # (start, end, is_closing, tag_name) for every <tag> / </tag> occurrence.
        tags = [
            (match.start(), match.end(), match.group(1) == "/", match.group(2))
            for match in re.finditer(r"<(/?)(\w+)>", cleaned)
        ]
    except (AttributeError, TypeError) as e:
        # Raised for non-string input (e.g. None, bytes).
        print(f"Error during parsing: {e}")
        return None

    elements = []

    def emit(tag, begin, end):
        """Append cleaned[begin:end] as an element if it is non-empty and the tag is recognised."""
        content = cleaned[begin:end].strip()
        if content and tag in content_tags:
            elements.append({"type": tag, "content": content})

    start = 0
    current_tag = None

    for pos, end, is_closing, tag_name in tags:
        if not is_closing:
            if current_tag is not None:
                # The model never closed the previous element: end it here.
                emit(current_tag, start, pos)
            current_tag = tag_name
            start = end
        elif current_tag == tag_name:
            emit(current_tag, start, pos)
            current_tag = None
            start = end

    # Trailing element that was opened but never closed.
    if current_tag is not None and start < len(cleaned):
        emit(current_tag, start, len(cleaned))

    return {"document": elements}

def safe_divide(numerator, denominator):
    """
    Divide, yielding 0 instead of raising when the denominator is zero.

    Args:
        numerator: Dividend.
        denominator: Divisor; a value equal to 0 short-circuits to 0.

    Returns:
        numerator / denominator, or the integer 0 for a zero denominator.
    """
    return 0 if denominator == 0 else numerator / denominator

def calculate_text_accuracy(ref_text, pred_text):
    """
    Score a predicted string against a reference with Levenshtein distance.

    Accuracy is (len(ref) - edit_distance) / len(ref), floored at 0. The
    returned word count is used by callers as the weight for this pair.

    Args:
        ref_text (str): Reference text.
        pred_text (str): Predicted text.

    Returns:
        tuple: (accuracy, word_count). An empty reference gives (1.0, 0) when
        the prediction is also empty and (0.0, 0) otherwise. Non-string input
        gives (0.0, 0).
    """
    # An empty reference carries no words, so it has zero weight downstream.
    if not ref_text:
        return (0.0, 0) if pred_text else (1.0, 0)

    if not (isinstance(ref_text, str) and isinstance(pred_text, str)):
        return 0.0, 0

    ref_len = len(ref_text)  # > 0 here, so the division below is safe
    edits = Levenshtein.distance(ref_text, pred_text)
    score = max(0, ref_len - edits) / ref_len
    return score, len(ref_text.split())

def calculate_overall_text_extraction_accuracy(reference, prediction):
    """
    Compute the word-weighted text extraction accuracy over a whole document.

    Elements are compared positionally, regardless of tag. The shorter list is
    padded with "":
    - Reference elements with no predicted counterpart are scored against an
      empty prediction, so missing output lowers the score.
    - Extra predicted elements have no reference words, so they get a word
      count of 0 from `calculate_text_accuracy` and do not shift the average.

    This matches how `calculate_text_accuracy_for_tag` treats unmatched
    elements.

    Args:
        reference (dict): {"document": [{"type": ..., "content": ...}, ...]}.
        prediction (dict): Same structure as `reference`.

    Returns:
        float: Overall accuracy between 0 and 1 (0 when the reference has no words).
    """
    from itertools import zip_longest

    ref_texts = [item["content"] for item in reference["document"]]
    pred_texts = [item["content"] for item in prediction["document"]]

    total_accuracy = 0
    total_words = 0
    for ref_text, pred_text in zip_longest(ref_texts, pred_texts, fillvalue=""):
        accuracy, word_count = calculate_text_accuracy(ref_text, pred_text)
        total_accuracy += accuracy * word_count
        total_words += word_count

    return safe_divide(total_accuracy, total_words)

def calculate_tag_categorization_accuracy(reference, prediction):
    """
    Measure how often the predicted element tags match the reference tags.

    Tags are compared position by position. Per-tag accuracy is the number of
    positions where both sides carry that tag, divided by how often the tag
    occurs in the reference. Overall accuracy is the number of matching
    positions divided by the longer of the two tag sequences.

    Args:
        reference (dict): {"document": [{"type": ..., "content": ...}, ...]}.
        prediction (dict): Same structure as `reference`.

    Returns:
        tuple: (tag_accuracy_dict, overall_accuracy).
        - tag_accuracy_dict maps each of the eight known tags to a float in
          [0, 1], or to -9999 when the tag never appears in the reference.
        - overall_accuracy is a float in [0, 1] (0 for two empty documents).
    """
    ref_tags = [element["type"] for element in reference["document"]]
    pred_tags = [element["type"] for element in prediction["document"]]

    # Tag at every position where prediction and reference agree.
    matched = [r_tag for r_tag, p_tag in zip(ref_tags, pred_tags) if r_tag == p_tag]

    known_tags = (
        "page_header",
        "paragraph",
        "subheading",
        "title",
        "image",
        "table",
        "code_snippet",
        "page_footer",
    )

    tag_accuracy_dict = {}
    for tag in known_tags:
        expected = ref_tags.count(tag)
        # -9999 marks "not evaluated"; avg_accuracy drops values below -9998 so they later show up as nan.
        tag_accuracy_dict[tag] = matched.count(tag) / expected if expected else -9999

    longest = max(len(ref_tags), len(pred_tags))
    overall_accuracy = len(matched) / longest if longest else 0

    return tag_accuracy_dict, overall_accuracy

def calculate_text_accuracy_for_tag(reference, prediction, tag):
    """
    Compute word-weighted text accuracy restricted to elements of one tag.

    Elements carrying `tag` are paired positionally. Leftover reference
    elements are scored against "", and leftover predicted elements are scored
    with "" as the reference.

    Args:
        reference (dict): {"document": [{"type": ..., "content": ...}, ...]}.
        prediction (dict): Same structure as `reference`.
        tag (str): Tag to evaluate.

    Returns:
        float: Accuracy in [0, 1], or -9999 when neither document contains the tag.
    """
    def texts_with_tag(doc):
        return [element["content"] for element in doc["document"] if element["type"] == tag]

    ref_texts = texts_with_tag(reference)
    pred_texts = texts_with_tag(prediction)

    # Sentinel for "not evaluated" — the tag is absent on both sides.
    if not (ref_texts or pred_texts):
        return -9999

    n_ref, n_pred = len(ref_texts), len(pred_texts)
    pairs = list(zip(ref_texts, pred_texts))
    pairs += [(ref_text, "") for ref_text in ref_texts[n_pred:]]
    pairs += [("", pred_text) for pred_text in pred_texts[n_ref:]]

    weighted_sum = 0
    word_total = 0
    for ref_text, pred_text in pairs:
        accuracy, words = calculate_text_accuracy(ref_text, pred_text)
        weighted_sum += accuracy * words
        word_total += words

    return safe_divide(weighted_sum, word_total)

def calculate_all_accuracies(reference, prediction):
    """
    Bundle every document-level metric for one reference/prediction pair.

    Args:
        reference (dict): {"document": [{"type": ..., "content": ...}, ...]}.
        prediction (dict): Same structure as `reference`.

    Returns:
        dict with three keys:
        - "overall_text_extraction": {"percentage": float}
        - "tag_categorization": {"overall_tag_accuracy": float,
          "tag_accuracy_dict": {tag: float or -9999}}
        - "text_extraction_by_tag": {tag: {"percentage": float or -9999}}
    """
    per_tag_order = (
        "paragraph",
        "subheading",
        "page_footer",
        "title",
        "image",
        "table",
        "page_header",
        "code_snippet",
    )

    overall_text_accuracy = calculate_overall_text_extraction_accuracy(reference, prediction)
    tag_accuracy_dict, overall_tag_accuracy = calculate_tag_categorization_accuracy(
        reference, prediction
    )
    text_by_tag = {
        tag: {"percentage": calculate_text_accuracy_for_tag(reference, prediction, tag)}
        for tag in per_tag_order
    }

    return {
        "overall_text_extraction": {"percentage": overall_text_accuracy},
        "tag_categorization": {
            "overall_tag_accuracy": overall_tag_accuracy,
            "tag_accuracy_dict": tag_accuracy_dict,
        },
        "text_extraction_by_tag": text_by_tag,
    }

def final_accuracy(reference_string, prediction_string):
    """
    Parse two tagged strings and compute all accuracy metrics between them.

    Args:
        reference_string (str): Ground-truth tagged string.
        prediction_string (str): Model-output tagged string.

    Returns:
        dict: Output of `calculate_all_accuracies`, or None (after printing an
        error) if either string could not be parsed.
    """
    parsed = [parse_tagged_string(text) for text in (reference_string, prediction_string)]

    if any(doc is None for doc in parsed):
        print("Error: Could not parse input strings.")
        return None

    reference_data, prediction_data = parsed
    return calculate_all_accuracies(reference_data, prediction_data)

def avg_accuracy(all_accs):
    """
    Average a list of per-document accuracy dictionaries into one summary.

    Values of -9999 (the "not evaluated" sentinel from the per-document
    metrics) are excluded before averaging. A metric with no evaluated values
    averages to NaN. None entries, which is what `final_accuracy` returns when
    parsing fails, are skipped. An empty list yields NaN overall scores and
    empty per-tag dicts instead of raising.

    Args:
        all_accs (list): Outputs of `final_accuracy` / `calculate_all_accuracies`.

    Returns:
        dict: Same layout as `calculate_all_accuracies`, except that the overall
        tag accuracy is stored under "tag_categorization" -> "percentage".
    """
    def evaluated_mean(values):
        kept = [v for v in values if v > -9998]
        # np.mean([]) is also nan but emits a RuntimeWarning.
        return np.mean(kept) if kept else np.nan

    valid = [acc for acc in all_accs if acc is not None]

    overall_text = evaluated_mean(
        [acc["overall_text_extraction"]["percentage"] for acc in valid]
    )
    overall_tag = evaluated_mean(
        [acc["tag_categorization"]["overall_tag_accuracy"] for acc in valid]
    )

    # Union of tags across documents, in first-seen order.
    per_tag_category = defaultdict(list)
    per_tag_text = defaultdict(list)
    for acc in valid:
        for tag, value in acc["tag_categorization"]["tag_accuracy_dict"].items():
            per_tag_category[tag].append(value)
        for tag, metrics in acc["text_extraction_by_tag"].items():
            per_tag_text[tag].append(metrics["percentage"])

    return {
        "overall_text_extraction": {"percentage": overall_text},
        "tag_categorization": {
            "percentage": overall_tag,
            "tag_accuracy_dict": {
                tag: evaluated_mean(values) for tag, values in per_tag_category.items()
            },
        },
        "text_extraction_by_tag": {
            tag: {"percentage": evaluated_mean(values)}
            for tag, values in per_tag_text.items()
        },
    }

# PNG -> XML

In [None]:
from PIL import Image
import pytesseract
import json
import os
from lxml import etree
from collections import defaultdict

# Location of the system tesseract binary that pytesseract shells out to.
pytesseract.pytesseract.tesseract_cmd = "/usr/bin/tesseract"

# COCO category names that are OCR'd and written to the XML ground truth;
# annotations with any other category are ignored.
allowed_categories = {
    "paragraph",
    "image",
    "title",
    "table",
    "page_header",
    "subheading",
    "code_snippet",
    "page_footer",
}

def extract_text_from_box(image_path, bbox):
    """
    OCR the text inside one bounding box of an image with Tesseract.

    Args:
        image_path (str): Path to the image file.
        bbox (sequence): (left, top, width, height) in pixels.

    Returns:
        str: The stripped OCR text, or "" if nothing was recognised or any
        error occurred. Errors are printed, not raised, so one bad box does not
        abort a whole batch.
    """
    try:
        left, top, width, height = bbox
        # This runs once per annotation, so release the file handle promptly
        # instead of leaving it open until garbage collection.
        with Image.open(image_path) as image:
            cropped_image = image.crop((left, top, left + width, top + height))
            # --oem 1 selects Tesseract's LSTM recognition engine.
            ocr_text = pytesseract.image_to_string(cropped_image, config='--oem 1')
        return ocr_text.strip() if ocr_text else ""
    except Exception as e:
        print(f"Error during OCR: {e}")
        return ""

def create_xml_for_image(image_info, coco_annotations, xml_output_path, image_folder):
    """
    Write an XML file of OCR'd regions for one image from its COCO annotations.

    Only annotations whose category name is in `allowed_categories` are kept.
    They are ordered top-to-bottom, then left-to-right, by the top-left corner
    of their bounding box. Each non-empty OCR result becomes a child element of
    <document><page file_name="..."> named after its category.

    Nothing is written (a message is printed instead) when the image file is
    missing or `xml_output_path` already exists, so interrupted runs can resume.

    Args:
        image_info (dict): One entry of coco_annotations["images"] ("id", "file_name").
        coco_annotations (dict): Loaded COCO annotations ("annotations", "categories").
        xml_output_path (str): Where to write the XML file.
        image_folder (str): Folder containing the image files.
    """
    image_id = image_info['id']
    file_name = image_info['file_name']
    image_path = os.path.join(image_folder, file_name)
    if not os.path.exists(image_path):
        print(f"Image file not found: {image_path}, skipping.")
        return
    if os.path.exists(xml_output_path):
        print(f"XML file already exists for {file_name}, skipping.")
        return

    # id -> name, built once per call instead of scanning the category list for
    # every annotation. reversed() makes the first entry win if an id is duplicated.
    category_names = {
        cat['id']: cat['name'] for cat in reversed(coco_annotations['categories'])
    }

    annotations_for_image = sorted(
        (
            ann for ann in coco_annotations['annotations']
            if ann['image_id'] == image_id
            and category_names.get(ann['category_id']) in allowed_categories
        ),
        key=lambda ann: (ann['bbox'][1], ann['bbox'][0]),
    )

    root = etree.Element("document")
    page = etree.SubElement(root, "page")
    page.set("file_name", file_name)
    for annotation in annotations_for_image:
        ocr_text = extract_text_from_box(image_path, annotation['bbox'])
        # lxml raises ValueError on control characters that are illegal in
        # XML 1.0, which would abort the whole batch; strip them first.
        ocr_text = re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1f]", "", ocr_text)
        if ocr_text:
            category_element = etree.SubElement(page, category_names[annotation['category_id']])
            category_element.text = ocr_text

    tree = etree.ElementTree(root)
    tree.write(xml_output_path, pretty_print=True, encoding="utf-8")
    print(f"XML file saved: {xml_output_path}")

def load_coco_annotations(coco_json_path):
    """
    Load a COCO-format annotation file.

    The file is decoded as UTF-8 (JSON's standard encoding) instead of the
    platform's locale default, so non-ASCII file or category names load the
    same on every machine.

    Args:
        coco_json_path (str): Path to the COCO JSON file.

    Returns:
        dict: The parsed annotations. The code in this cell reads the
        "images", "annotations" and "categories" keys.
    """
    with open(coco_json_path, encoding="utf-8") as coco_file:
        return json.load(coco_file)

def process_all_images(image_folder, coco_json_path, output_xml_folder):
    """
    Create one XML ground-truth file per annotated PNG in `image_folder`.

    PNGs without an entry in the COCO "images" list are ignored. Images are
    processed in sorted filename order, so the run order is reproducible.

    Args:
        image_folder (str): Folder containing the .png images.
        coco_json_path (str): Path to the COCO JSON file.
        output_xml_folder (str): Destination folder for the XML files (created if missing).
    """
    coco_annotations = load_coco_annotations(coco_json_path)
    os.makedirs(output_xml_folder, exist_ok=True)
    image_info_dict = {img['file_name']: img for img in coco_annotations['images']}

    for image_file in sorted(os.listdir(image_folder)):
        if not image_file.endswith('.png'):
            continue
        image_info = image_info_dict.get(image_file)
        if not image_info:
            continue
        # splitext swaps only the final extension; str.replace would rewrite
        # every ".png" occurring anywhere in the name.
        xml_name = os.path.splitext(image_file)[0] + '.xml'
        xml_output_path = os.path.join(output_xml_folder, xml_name)
        create_xml_for_image(image_info, coco_annotations, xml_output_path, image_folder)

# Example usage (update with your paths):
# Converts the images in test_images into XML ground truth in test_xml using the
# COCO annotations in all.json.
# NOTE(review): only the test split is converted here; the train_xml/eval_xml folders
# read by the training cell are presumably produced by re-running with those paths — confirm.
image_folder = "/content/drive/MyDrive/Unstructured Project/Unstructured_Jay/code/all_images/test_images"
output_xml_folder = "/content/drive/MyDrive/Unstructured Project/Unstructured_Jay/code/all_images/test_xml"
coco_json_path = '/content/drive/MyDrive/Unstructured Project/Unstructured_Jay/code/all_images/all.json'

# XML files that already exist are skipped, so re-running only fills in missing ones.
process_all_images(image_folder, coco_json_path, output_xml_folder)

# Train Donut

In [None]:
import json
import os
import shutil
import torch
from PIL import Image
import os
import json
import xml.etree.ElementTree as ET
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel, Seq2SeqTrainingArguments, Seq2SeqTrainer
import shutil
from datasets import Dataset, Features, Value, Array3D, Sequence
from huggingface_hub import HfFolder, login
import subprocess
import torch
from google.colab import userdata

# Authenticate to the Hugging Face Hub with the token stored in the Colab secret 'huggingfaceapi'.
login(userdata.get('huggingfaceapi'))

# These settings are read from the process environment, so they must be set in
# os.environ; plain Python variables are never seen by the libraries. The
# module-level names are kept only so any existing references keep working.
PYDEVD_DISABLE_FILE_VALIDATION=1
TOKENIZERS_PARALLELISM=False
os.environ["PYDEVD_DISABLE_FILE_VALIDATION"] = "1"  # only affects debugger processes started after this point
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # disable the Rust tokenizers' internal parallelism
Image.MAX_IMAGE_PIXELS = 130436144 #Some images had too many pixels; this bypasses that restriction.
base_model = "naver-clova-ix/donut-base"
finetuned_model = "nielsr/donut-docvqa-demo"
saved_model = "JayJai/donut_gpu"
processor = DonutProcessor.from_pretrained(saved_model)
model = VisionEncoderDecoderModel.from_pretrained(saved_model)
# NOTE(review): this is a second, independent processor instance; the special tokens
# later added to `processor.tokenizer` are not added to it.
config = DonutProcessor.from_pretrained(saved_model)

def extract_text_from_xml(xml_path):
    """
    Collect the non-empty text of every element in an XML file.

    The tree is walked in document order, starting with the root itself. Each
    element whose stripped text is non-empty contributes one single-entry
    dict {tag: text}.

    Args:
        xml_path (str): Path to the XML file.

    Returns:
        list: [{tag: text}, ...] in document order. An empty list is returned
        (after printing a message) if the file is missing or is not well-formed XML.
    """
    if not os.path.exists(xml_path):
        print(f"Error: XML file not found at {xml_path}")
        return []

    try:
        root = ET.parse(xml_path).getroot()
    except ET.ParseError as e:
        print(f"Error parsing XML file: {e}")
        return []

    tagged_texts = ((element.tag, (element.text or "").strip()) for element in root.iter())
    return [{tag: text} for tag, text in tagged_texts if text]

def create_json(image_folder, xml_folder):
    """
    Pair each XML ground-truth file with its PNG image and extract its text.

    An XML file "name.xml" is included only if "name.png" exists in
    `image_folder`. Files are visited in sorted order, so the output order is
    reproducible. This matters because the datasets are later built from index
    ranges of this list, and os.listdir order is arbitrary.

    Args:
        image_folder (str): Folder containing the .png images.
        xml_folder (str): Folder containing the .xml files.

    Returns:
        list: [{"file_name": "<name>.png", "text": [{tag: text}, ...]}, ...].
    """
    json_data = []
    for xml_file in sorted(os.listdir(xml_folder)):
        if not xml_file.endswith(".xml"):
            continue
        image_file = os.path.splitext(xml_file)[0] + ".png"
        if not os.path.exists(os.path.join(image_folder, image_file)):
            continue
        xml_path = os.path.join(xml_folder, xml_file)
        json_data.append({
            "file_name": image_file,
            "text": extract_text_from_xml(xml_path),
        })
    return json_data

def json2token(obj, update_special_tokens_for_json_key: bool = True, sort_json_key: bool = True, special_tokens=None):
    """
    Convert a JSON object (dict or list) into a Donut token sequence.

    Encoding rules:
    - A dict becomes "<key>value</key>" for each key.
    - A list joins its encoded items with "<sep/>".
    - A dict whose only key is "text_sequence" returns that value verbatim.
    - Any other value is str()-ed, and is emitted as "<value/>" if that token
      is already registered as a special token.

    Args:
        obj (dict or list): The JSON object to convert.
        update_special_tokens_for_json_key (bool, optional): Whether to register
            "<key>"/"</key>" tokens for every dict key seen. Defaults to True.
        sort_json_key (bool, optional): Emit dict keys in reverse-sorted order
            (otherwise insertion order). Defaults to True.
        special_tokens (list, optional): List that collects new tokens and is
            consulted for "<value/>" tokens. Defaults to the module-level
            `new_special_tokens`.

    Returns:
        str: The token sequence.
    """
    if isinstance(obj, dict) and len(obj) == 1 and "text_sequence" in obj:
        return obj["text_sequence"]

    if special_tokens is None:
        special_tokens = new_special_tokens

    if isinstance(obj, dict):
        keys = sorted(obj.keys(), reverse=True) if sort_json_key else obj.keys()
        parts = []
        for k in keys:
            open_token, close_token = f"<{k}>", f"</{k}>"
            if update_special_tokens_for_json_key:
                for token in (open_token, close_token):
                    if token not in special_tokens:
                        special_tokens.append(token)
            inner = json2token(obj[k], update_special_tokens_for_json_key, sort_json_key, special_tokens)
            parts.append(open_token + inner + close_token)
        return "".join(parts)

    if isinstance(obj, list):
        return "<sep/>".join(
            json2token(item, update_special_tokens_for_json_key, sort_json_key, special_tokens)
            for item in obj
        )

    obj = str(obj)
    value_token = f"<{obj}/>"
    return value_token if value_token in special_tokens else obj

def preprocess_documents_for_donut(sample):
    """
    Build the Donut target sequence for one document record.

    Args:
        sample (str | bytes | dict): A JSON document, or an already-decoded
            dict, with "file_name" and "text" fields.

    Returns:
        dict: {"image": file name, "text": task_start_token + json2token(text) + eos_token}.
    """
    if not isinstance(sample, dict):
        sample = json.loads(sample)
    d_doc = task_start_token + json2token(sample["text"]) + eos_token
    # Only the file name is recorded here; images are opened, resized and
    # converted to RGB later, in transform_and_tokenize.
    return {"image": sample["file_name"], "text": d_doc}

def process_json(input_file, output_file):
    """
    Convert a JSON list of document records into Donut training records.

    The input must be a JSON list of {"file_name", "text"} dicts. Each valid
    record is passed through `preprocess_documents_for_donut`, and the results
    are written to `output_file` as a JSON list. Invalid records are skipped
    with a warning. If the input is not a list, nothing is written. Both files
    use UTF-8.

    Args:
        input_file (str): Path to the input JSON file.
        output_file (str): Path to write the processed JSON file.
    """
    with open(input_file, 'r', encoding='utf-8') as infile:
        data = json.load(infile)

    if not isinstance(data, list):
        print("Error: The JSON file does not contain a list.")
        return

    processed_data = []
    for item in data:
        # Check isinstance first: for a string item, `'file_name' in item`
        # would be a substring test, and item['text'] would then fail.
        if isinstance(item, dict) and 'file_name' in item and 'text' in item:
            processed_data.append(preprocess_documents_for_donut(json.dumps(item)))
        else:
            print(f"Warning: Skipping invalid item: {item}")

    with open(output_file, 'w', encoding='utf-8') as outfile:
        json.dump(processed_data, outfile, indent=4)

    print(f"Processed JSON saved to {output_file}")

def transform_and_tokenize(sample, processor=processor, training_split='train', max_length=1536, ignore_id=-100, image_folder='error'):
    """
    Turn one {"image", "text"} record into model inputs.

    The image is opened from `image_folder + '/' + sample["image"]`, resized
    to 960x1280 (width x height), converted to RGB and passed through the
    processor. The text is tokenized without extra special tokens and padded
    or truncated to `max_length`.

    Args:
        sample (dict): Record with "image" (file name) and "text" (target sequence).
        processor (DonutProcessor): Image + text processor. The default is bound
            when this function is defined.
        training_split (str, optional): Not used by this function; kept for
            interface compatibility.
        max_length (int, optional): Token length to pad/truncate to. Defaults to 1536.
        ignore_id (int, optional): Label value written over padding positions so
            the loss skips them (-100 is PyTorch cross-entropy's default
            ignore_index). Defaults to -100.
        image_folder (str, optional): Folder containing the images. Defaults to 'error'.

    Returns:
        dict: {"pixel_values", "labels", "target_sequence", "image_folder" (full
        image path)}, or {} if the image could not be loaded or processed.
    """
    try:
        image_path = image_folder + '/' + sample["image"]
        # The context manager closes the file; resize() returns an in-memory copy.
        with Image.open(image_path) as image:
            resized = image.resize((960, 1280))
        pixel_values = processor(resized.convert('RGB'), return_tensors="pt").pixel_values.squeeze()
    except Exception as e:
        print(f"Error: {e}")
        return {}

    # Tokenize the target document.
    # NOTE: with truncation=True, sequences longer than max_length lose their
    # trailing tokens, including the eos token appended during preprocessing.
    input_ids = processor.tokenizer(
        sample["text"],
        add_special_tokens=False,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )["input_ids"].squeeze(0)

    labels = input_ids.clone()
    labels[labels == processor.tokenizer.pad_token_id] = ignore_id  # model doesn't need to predict pad token
    return {
        "pixel_values": pixel_values,
        "labels": labels,
        "target_sequence": sample["text"],
        "image_folder": image_path,
    }

def data_generator(input_file, start_index, end_index, image_folder):
    """
    Yield model-ready samples for records [start_index, end_index) of a JSON file.

    - `end_index` is clamped to the number of records, so a range larger than
      the file no longer raises IndexError.
    - Records missing "image"/"text" are skipped with a warning.
    - Records whose image fails to load are also skipped:
      `transform_and_tokenize` returns {} for them, which does not match the
      dataset features.

    Args:
        input_file (str): Path to a processed JSON file (a list of {"image", "text"} records).
        start_index (int): First record index to process.
        end_index (int): One past the last record index to process.
        image_folder (str): Folder containing the images.

    Yields:
        dict: Output of `transform_and_tokenize` for each usable record.
    """
    with open(input_file, 'r', encoding='utf-8') as infile:
        data = json.load(infile)

    for i in range(start_index, min(end_index, len(data))):
        item = data[i]
        if not (isinstance(item, dict) and 'image' in item and 'text' in item):
            print(f"Warning: Skipping invalid item: {item}")
            continue
        processed = transform_and_tokenize(item, image_folder=image_folder)
        if processed:
            yield processed
        else:
            print(f"Warning: Skipping sample that failed to load: {item['image']}")

# Dataset feature types.
# Schema every generated example must match. pixel_values is declared as
# (3, 1280, 960); presumably (channels, height, width), matching the 960x1280
# (width x height) resize in transform_and_tokenize — confirm against the processor output.
your_dataset_features = Features(
    {
        'pixel_values': Array3D(shape=(3, 1280, 960), dtype="float32"),
        'labels': Sequence(feature=Value(dtype='int64')),
        'target_sequence': Value('string'),
        'image_folder': Value('string'),
    }
)

# Root of the dataset on the mounted Google Drive.
_all_images_root = "/content/drive/MyDrive/Unstructured Project/Unstructured_Jay/code/all_images"
_intermediate_root = _all_images_root + "/intermediate"

# Page images and per-image XML ground truth, one folder per split.
train_images = f"{_all_images_root}/train_images"
eval_images = f"{_all_images_root}/eval_images"
test_images = f"{_all_images_root}/test_images"
train_xml = f"{_all_images_root}/train_xml"
eval_xml = f"{_all_images_root}/eval_xml"
test_xml = f"{_all_images_root}/test_xml"

# Intermediate files: *_data.json holds create_json output (XML text per image);
# *_processed.json holds the Donut records written by process_json.
train_input_file = f"{_intermediate_root}/train_data.json"
eval_input_file = f"{_intermediate_root}/eval_data.json"
test_input_file = f"{_intermediate_root}/test_data.json"
train_output_file = f"{_intermediate_root}/train_processed.json"
eval_output_file = f"{_intermediate_root}/eval_processed.json"
test_output_file = f"{_intermediate_root}/test_processed.json"

#XML -> JSON Conversion
train_json_output = create_json(train_images, train_xml)
eval_json_output = create_json(eval_images, eval_xml)
test_json_output = create_json(test_images, test_xml)

# Save every split. process_json later reads the train, eval AND test input
# files, so the test split must be written too; otherwise it would read a
# stale or missing test_data.json.
with open(train_input_file, "w", encoding="utf-8") as json_file:
    json.dump(train_json_output, json_file, indent=4)

with open(eval_input_file, "w", encoding="utf-8") as json_file:
    json.dump(eval_json_output, json_file, indent=4)

with open(test_input_file, "w", encoding="utf-8") as json_file:
    json.dump(test_json_output, json_file, indent=4)

# Globals read by json2token / preprocess_documents_for_donut, so they must exist
# before process_json runs. json2token appends each "<key>"/"</key>" tag it
# encounters to new_special_tokens.
new_special_tokens = [] # new tokens for doc structure
task_start_token = "<s>"  # prepended to every target sequence
eos_token = "</s>"  # appended to every target sequence

#JSON -> Donut readable format
# Each call writes a list of {"image": file_name, "text": "<s>...</s>"} records.
process_json(train_input_file, train_output_file)
process_json(eval_input_file, eval_output_file)
process_json(test_input_file, test_output_file)
print(f"New special tokens: {new_special_tokens + [task_start_token] + [eos_token]}")

# Register the structure tags plus <s>/</s> as special tokens so the tokenizer never splits them.
# Must happen before the decoder embeddings are resized to len(processor.tokenizer).
processor.tokenizer.add_special_tokens({"additional_special_tokens": new_special_tokens + [task_start_token] + [eos_token]}) # add new special tokens to tokenizer
# NOTE(review): recent transformers versions read a 2-item list here as (height, width),
# i.e. 960 high x 1280 wide. transform_and_tokenize resizes to 960 wide x 1280 high, and
# your_dataset_features declares pixel_values as (3, 1280, 960). Confirm the processor's
# output shape; {"height": 1280, "width": 960} may be what is intended.
processor.image_processor.size =  [960, 1280] #reduced for faster training and inference
# Presumably disables Donut's automatic rotation of pages whose long axis does not match the target size — confirm.
processor.image_processor.do_align_long_axis = False

# Model, training parameters, and paths
hf_repository_id = "JayJai/donut_gpu"  # Your Hugging Face repository ID, this is a private repo

new_emb = model.decoder.resize_token_embeddings(len(processor.tokenizer))
print(f"New embedding size: {new_emb}")
model.config.encoder.image_size = [960, 1280]
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(['<s>'])[0]

#Train, Eval, Test dataset creation
train_data = Dataset.from_generator(
    lambda: data_generator(train_output_file, 0, 2000, train_images),
    features=your_dataset_features,
    writer_batch_size= 300
)

eval_data = Dataset.from_generator(
    lambda: data_generator(eval_output_file, 0, 300, eval_images),
    features=your_dataset_features,
    writer_batch_size= 150
)

test_data = Dataset.from_generator(
    lambda: data_generator(test_output_file, 0, 1000, test_images),
    features=your_dataset_features,
    writer_batch_size=5,
    keep_in_memory=False
)

In [None]:
# Imported up front so a missing plotting dependency fails before, not after,
# a long training run.
import matplotlib.pyplot as plt

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
# Only affects Apple MPS devices; `device` above is always "cuda" or "cpu".
os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = "0.0"
learning_rate = 2e-05

# Training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir=hf_repository_id,
    num_train_epochs=8,  # segments are split into quarter so effectively 2 epochs.
    learning_rate=learning_rate,
    per_device_train_batch_size=4,
    weight_decay=0.01,
    fp16=False,
    logging_steps=1,
    warmup_steps=0,
    save_total_limit=2,
    eval_strategy="epoch",
    save_strategy="epoch",
    predict_with_generate=True,
    report_to="tensorboard",
    push_to_hub=True,
    hub_strategy="every_save",
    hub_model_id=hf_repository_id,
    hub_token=HfFolder.get_token(),
)
# Cap decoder length at the longest tokenized label in the training set.
model.config.decoder.max_length = len(max(train_data["labels"], key=len))
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=eval_data,
    processing_class=processor,
)
trainer.train()
trainer.save_model(hf_repository_id)
# Save the processor next to the model using the same repo id as everything
# else (previously a duplicated hard-coded string).
processor.save_pretrained(hf_repository_id)
trainer.create_model_card()
trainer.push_to_hub()

# Plot the per-step training loss to check whether the curve has flattened.
training_losses = [log["loss"] for log in trainer.state.log_history if "loss" in log]
plt.plot(training_losses)
plt.xlabel("Steps")
plt.ylabel("Loss")
plt.title("Training Loss")
plt.show()

Epoch,Training Loss,Validation Loss
1,0.3731,0.703171
2,0.4319,0.677876
3,0.344,0.66064
4,1.3035,0.658374


# Test Donut

In [None]:
!huggingface-cli login

In [None]:
import transformers
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel, AutoProcessor, AutoTokenizer, AutoModelForImageTextToText
import torch
import random
from datasets import load_from_disk
import re
from tqdm import tqdm
import numpy as np
from huggingface_hub import HfFolder, notebook_login
from google.colab import userdata
from huggingface_hub import notebook_login
from google.colab import userdata
import textwrap
import json
import re
from collections import defaultdict

# hide logs
transformers.logging.disable_default_handler()
hf_repository_id = 'JayJai/donut_gpu'  # private repo

# Load the fine-tuned model together with the processor saved alongside it.
# run_prediction() below decodes with `processor` (including the added special
# tokens), so it must come from the same checkpoint as the model rather than
# from whatever global happened to be left over from the training cell.
processor = DonutProcessor.from_pretrained(hf_repository_id)
tokenizer = processor.tokenizer
model = AutoModelForImageTextToText.from_pretrained(hf_repository_id)

# Move model to GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

def run_prediction(test_sample, model=model, processor=processor, *,
                   do_sample=False, num_beams=3, verbose=True):
    """
    Generate a structured transcription for one test sample.

    Args:
        test_sample (dict): Dataset row containing "pixel_values" (converted with
            torch.tensor and given a leading batch dimension here) and
            "target_sequence" (the ground-truth string).
        model: Model used for generation. Defaults to the global `model` as it
            was when this function was defined.
        processor: Processor whose tokenizer encodes the "<s>" task prompt and
            decodes the output. Defaults to the global `processor` at definition time.
        do_sample (bool): Whether generate() samples. Defaults to False so that
            evaluation is deterministic; the previous hard-coded True made the
            reported accuracy vary from run to run.
        num_beams (int): Beam width passed to generate().
        verbose (bool): Print the decoded sequence (previous behavior).

    Returns:
        tuple[str, str]: The decoded prediction with eos/pad tokens removed, and
        the ground-truth target sequence.
    """
    # prepare inputs
    pixel_values = torch.tensor(test_sample["pixel_values"]).unsqueeze(0)
    task_prompt = "<s>"
    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids

    # run inference; `device` is the module-level device the model was moved to
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        do_sample=do_sample,
        num_beams=num_beams,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # process output
    seq = processor.batch_decode(outputs.sequences)[0]
    seq = seq.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    if verbose:
        print(seq)

    # load reference target
    ground_truth = test_sample["target_sequence"]
    return seq, ground_truth

all_accs = []
# Accuracy testing: predict each test sample once and score it once.
# (Previously the scoring ran inside `for t in zip(prediction, target)`, which
# iterates characters, so every sample was appended once per character and
# longer pages dominated the average.)
n = 0
for n, sample in enumerate(test_data, start=1):
    print(n)
    try:
        prediction, target = run_prediction(sample)
        accuracies = final_accuracy(target, prediction)
    except Exception as e:
        # Same best-effort policy as the Claude evaluation: report and move on.
        print(f"Skip: {e}")
        continue
    if accuracies:
        all_accs.append(accuracies)

# Average once, after every sample has been scored.
average_accuracies = avg_accuracy(all_accs)

# Unpack the averaged Donut scores: headline numbers first, then the per-tag
# text-extraction percentages.
overall_text_accuracies = average_accuracies["overall_text_extraction"]["percentage"]

# Calculating different accuracy metrics
tag_categorization_accuracies = average_accuracies["tag_categorization"]["percentage"]

text_extraction_by_tag = average_accuracies["text_extraction_by_tag"]
paragraph_text_accuracies = text_extraction_by_tag["paragraph"]["percentage"]
image_text_accuracies = text_extraction_by_tag["image"]["percentage"]
title_text_accuracies = text_extraction_by_tag["title"]["percentage"]
table_text_accuracies = text_extraction_by_tag["table"]["percentage"]
page_header_text_accuracies = text_extraction_by_tag["page_header"]["percentage"]
subheading_text_accuracies = text_extraction_by_tag["subheading"]["percentage"]
code_snippet_text_accuracies = text_extraction_by_tag["code_snippet"]["percentage"]
page_footer_text_accuracies = text_extraction_by_tag["page_footer"]["percentage"]

tag_specific_accuracies = average_accuracies["tag_categorization"]["tag_accuracy_dict"]

# printing all accuracies
print(f"Average Overall Text Extraction Accuracy: {overall_text_accuracies:.4f}")
print(f"Average Tag Categorization Accuracy: {tag_categorization_accuracies:.4f}")

print("Tag-Specific Text Extraction Accuracies:")
for tag, accuracy in text_extraction_by_tag.items():
    for subtag, subaccuracy in accuracy.items():
        # Include the metric key: labelling only by tag made every value
        # printed for a given tag indistinguishable.
        print(f"  {tag} ({subtag}): {subaccuracy:.4f}")

print("Tag Extraction Accuracies:")
for tag, accuracy in tag_specific_accuracies.items():
    print(f"  {tag}: {accuracy:.4f}")

# Test Claude

In [None]:
from google.colab import userdata
from anthropic import Anthropic
import torch
import numpy as np
from PIL import Image
import json
import io
import base64
import regex as re

# Claude's prompt: instructions, the allowed tag set, and two worked examples in
# the exact <s>...</s> target format used for training.
# Fixed: the instructions previously said "Always end with <s>." although every
# example (and the training targets) end with </s>; also fixed two typos.
prompting = """You are a document structure expert. You identify the document's structure and delineate paragraphs, not group them.
You only output in the same exact format as the examples without any extra words. Always end with </s>.
These are the tags you can use:
paragraph, image, title, table, page_header, subheading, code_snippet, page_footer
<example> <s><page_header>J. Theor. Appl. Electron. Commer. Res. 2021, 16</page_header><sep/><paragraph>been accompanied by the rise of the Internet and the rise of online consumer behavior\u2014B2C\nand C2B research. For example, in the case of B2C, it is essentially a research-led analysis\nof the role of a company\u2019s existing value model in creating its value for consumers [52]. For\nexample, Agnihotri R [53] argues that social media has changed the way buyers and sellers\ninteract with each other. The use of social media and information dissemination behaviors\nby companies acting on consumers can enhance customer satisfaction and deliver value. In\ncontrast, in C2B research, the consumer is used as the research lead, dissecting the impact\nof purchase motivation, behavior, and feedback on business activities in the above diagram.\nFor example, Cortez RM [54] argues that value creation occurs mainly in increasing a\nfirm\u2019s customer assets, using big data for marketing precision, and optimizing the industry\nenvironment and platform ecosystem, which are determined mainly by the people the\nfirm serves\u2014the consumers. 
Online consumer behavior and the concept of production\nand consumption have brought about a change in the structure of the value chain and\ndisruption in the way value is created for traditional enterprises, resulting in multiple\nresearch paths for B2B, C2C, and B2C and C2B.</paragraph><sep/><image>information\n\nservicejquality\nvirtual communities\n- moderating role\n\\ value\nbehaviorahjintentions\n\nbrand eaperience customer agement\nm \u201cey nono oy mee\n\n+ \\ gt ded@opmerit \u00b0\nPartidpation\n\nee\n\n  \n     \n  \n  \n \n\nS nh\ncustomer @grticipation \u00a9 % Footer i Wh \u2018on $0\u00a2 edia\n\nr 3 ig brand cammunity validation\nstructural eq@etion madels\n\nanager ere mii\n\nfinanciagservices _\n\n \n\ncoproguction qualitative research .\nP \u00b0 wt \u2018value j @ community\nsharing @conomy , ;\n\u00a9 customer valu? \u2014 , online\ntues service agginant logic\nqu@bty value @eation value c@greation ! \u2014\u2014persfigctive\n\n\u201cservice\nf2M@V0'berceptions\n\u00a9 consumption...\ninformationstechnology\n\nadoption commerce</image><sep/><paragraph>Figure 6. Keyword analysis for the last 5 years (2015-2019) in the area of online consumer behavior\u2014value co-creation.</paragraph><sep/><paragraph>Social media research</paragraph><sep/><paragraph>With the popularity of social media, a revolution in media and interaction is sweeping\nconsumers and markets. When online consumers visit Facebook and Little Red Book,\npost tweets and microblogs, or browse, like, comment and retweet on social media, social\nmedia marketing silently permeates these online consumer behaviors [55]. In the diagram,\nthe terms \u201cFacebook,\u201d \u201cTwitter,\u201d \u201cSocial networking,\u201d \u201cSocial media,\u201d \u201cRelationship,\u201d\nand \u201cElectronic word of mouth\u201d allude to the academic understanding of this change.\nThe appearance of these keywords alludes to the academic attention and interest in this\n\nchange [56].</paragraph></s>
</example>
<example> <s><page_header>XI IRCSA CONFERENCE -- PROCEEDINGS</page_header><sep/><paragraph>0.02) per litre. This works out as 40 cents ($ 0.40) per family per day (Table 3: Costing\n\nexercise).</paragraph><sep/><paragraph>Similarly the cost of setting up a rooftop rainwater harvesting, storage and purification system\nusing the RainPC for a community of 14 families (70 people) has been worked out.\nConsidering the useful lifespan of the RainPC as four years, the cost of providing safe\ndrinking water is 1.4 cents ($ 0.014; Table 3: Costing exercise). This works out as 28 cents\nper family per day ($ 0.28).</paragraph><sep/><paragraph>Conclusions:</paragraph><sep/><paragraph>Rain has tremendous potential to meet drinking water requirements. Its use to meet drinking\nwater requirements can no longer be ignored. There is a need to rethink about water supply\nand rain and use it for potable purposes. If the quality is poor, the rainwater can be treated\nand made potable by technologies like the RainPC. The Water Supply and Sanitation\nCollaborative Council (WSSCC) also believes that \u201cwater supply services are more\nsuccessful when people feel they are responsible for, and benefit from, them. Such\n\napproaches also lend themselves to \u2018 scaling up.\u2019 In this context modern versions of old\nstrategies such as household rainwater harvesting have enormous potential. The private\nsector in the form of local artisans, masons and small-scale manufacturers can develop and\nmarket low-cost technologies. In this way, better sanitation and water supply also contributes\n\nto the local economy.\u201d</paragraph><sep/><paragraph>If indeed, increasing access to household and community safe drinking water has to be\n\nachieved, the inexpensive and rapid solution lies in rain.</paragraph></s>
</example>
"""

# NOTE(review): defined but not included in the request built by anthropic().
question = "Identify document structure:"

def anthropic(image, max_tokens=1024):
    """
    Send a page image plus the structure prompt to Claude and return its reply.

    The image is PNG-encoded, base64-embedded as an image content block, and
    followed by the module-level `prompting` text in a single user message.

    Args:
        image (PIL.Image.Image): The page image to analyse.
        max_tokens (int): Response-length cap passed to the Messages API.
            Defaults to 1024, the previously hard-coded value.

    Returns:
        str: Text of the first content block of Claude's response.
    """
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")  # Save as PNG to the buffer
    img_str = base64.b64encode(buffered.getvalue()).decode()
    client = Anthropic(api_key=userdata.get('claudeapi'))
    message = client.messages.create(
        model="claude-3-5-sonnet-20241022",
        max_tokens=max_tokens,
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/png",
                            "data": img_str,
                        },
                    },
                    {
                        "type": "text",
                        "text": prompting
                    }
                ],
            }
        ],
    )
    # A reply cut off at the token limit loses its trailing elements and
    # depresses the accuracy numbers; make that visible instead of silent.
    if message.stop_reason == "max_tokens":
        print(f"Warning: Claude response truncated at max_tokens={max_tokens}")
    return message.content[0].text

def pixels_to_image(tensor):
    """
    Convert a channel-first image tensor with values in [-1, 1] to a PIL Image.

    Args:
        tensor (torch.Tensor): Pixel values, channels first (C, H, W), expected in
            [-1, 1]; values outside are clamped. May live on any device.

    Returns:
        PIL.Image.Image: The corresponding 8-bit image.
    """
    tensor = tensor.detach().cpu().float().clamp(-1, 1)
    # Map [-1, 1] -> [0, 255]. Round rather than truncate: float error can turn
    # an original level such as 128 into 127.9999, which truncation drops to 127.
    tensor = ((tensor + 1) / 2 * 255).round()
    tensor = tensor.permute(1, 2, 0).to(torch.uint8)
    if tensor.shape[-1] == 1:
        # PIL cannot build an image from an (H, W, 1) array; use (H, W) grayscale.
        tensor = tensor.squeeze(-1)
    return Image.fromarray(tensor.numpy())

processed_data = []
# Ask Claude for the structure of every test page and pair each reply with its
# ground-truth target sequence: a list of {"prediction", "target_sequence"} dicts.
for item in test_data:
    if not (isinstance(item, dict) and 'pixel_values' in item and 'target_sequence' in item):
        # Report only the keys/type: dataset rows can carry the full pixel
        # array, which would flood the notebook output if printed whole.
        summary = list(item) if isinstance(item, dict) else type(item).__name__
        print(f"Warning: Skipping invalid item: {summary}")
        continue
    try:
        pixels = item['pixel_values']
        processed_item = anthropic(pixels_to_image(torch.tensor(pixels)))
        processed_data.append({"prediction": processed_item, "target_sequence": item["target_sequence"]})
    except Exception as e:
        # Best-effort: an API/conversion failure skips this page only.
        print(f"Skip: {e}")

In [None]:
import json
import re
from collections import defaultdict
import re

# Score every Claude prediction against its target. Samples whose scoring
# raises are reported and excluded; empty/falsy score dicts are ignored.
# (output_list is defined here but not used in this cell.)
output_list = []
all_accs = []

for item in processed_data:
    try:
        accuracies = final_accuracy(item['target_sequence'], item['prediction'])
    except Exception as e:
        print(f"Skip: {e}")
        continue
    if accuracies:
        all_accs.append(accuracies)

average_accuracies = avg_accuracy(all_accs)

# Extract the headline and per-tag percentages from Claude's averaged scores,
# then print them in a fixed order.
_by_tag = average_accuracies['text_extraction_by_tag']
overall_text_accuracies = average_accuracies['overall_text_extraction']['percentage']
tag_categorization_accuracies = average_accuracies['tag_categorization']['percentage']
paragraph_text_accuracies = _by_tag['paragraph']['percentage']
image_text_accuracies = _by_tag['image']['percentage']
title_text_accuracies = _by_tag['title']['percentage']
table_text_accuracies = _by_tag['table']['percentage']
page_header_text_accuracies = _by_tag['page_header']['percentage']
subheading_text_accuracies = _by_tag['subheading']['percentage']
code_snippet_text_accuracies = _by_tag['code_snippet']['percentage']
page_footer_text_accuracies = _by_tag['page_footer']['percentage']

for _label, _value in (
    ("Overall Text Extraction", overall_text_accuracies),
    ("Tag Categorization", tag_categorization_accuracies),
    ("Paragraph Extraction", paragraph_text_accuracies),
    ("Image Extraction", image_text_accuracies),
    ("Title Extraction", title_text_accuracies),
    ("Table Extraction", table_text_accuracies),
    ("Page Header Extraction", page_header_text_accuracies),
    ("Subheading Extraction", subheading_text_accuracies),
    ("Code Snippet Extraction", code_snippet_text_accuracies),
    ("Page Footer Extraction", page_footer_text_accuracies),
):
    print(f"Average {_label} Accuracy: {_value:.4f}")

Average Overall Text Extraction Accuracy: 0.3720
Average Tag Categorization Accuracy: 0.3843
Average Paragraph Extraction Accuracy: 0.3933
Average Image Extraction Accuracy: 0.0335
Average Title Extraction Accuracy: 0.3368
Average Table Extraction Accuracy: 0.1870
Average Page Header Extraction Accuracy: 0.3789
Average Subheading Extraction Accuracy: 0.3684
Average Code Snippet Extraction Accuracy: 0.5867
Average Page Footer Extraction Accuracy: 0.2880
