In [1]:
!pip install -q transformers datasets sentencepiece

In [2]:
!pip install -q pytorch-lightning wandb torchvision

In [22]:
from transformers import VisionEncoderDecoderConfig

image_size = [1280, 960]
max_length = 512

config = VisionEncoderDecoderConfig.from_pretrained("naver-clova-ix/donut-base")
config.encoder.image_size = image_size
config.decoder.max_length = max_length

In [23]:
from transformers import DonutProcessor, VisionEncoderDecoderModel

processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base", config=config)

In [24]:
import re

def preprocess_html(html):
    # Replace spaces within style attributes with a non-breaking space
    html = re.sub(r'(style="[^"]+")', lambda x: x.group(1).replace(' ', '\xa0'), html)
    return html

In [25]:
import json
import random
from typing import Any, List, Tuple
from PIL import Image
from torchvision import transforms

import torch
from torch.utils.data import Dataset

added_tokens = []

class HtmlTablesDataset(Dataset):
    def __init__(
        self,
        json_file: str,
        max_length: int,
        split: str = "train",
        ignore_id: int = -100,
        task_start_token: str = "",
        prompt_end_token: str = None,
        sort_json_key: bool = True,
    ):
        super().__init__()
        self.max_length = max_length
        self.split = split
        self.ignore_id = ignore_id
        self.task_start_token = task_start_token
        self.prompt_end_token = prompt_end_token if prompt_end_token else task_start_token
        self.sort_json_key = sort_json_key

        with open(json_file, 'r') as file:
            # Split data
            data = json.load(file)
            total_len = len(data)
            train_end = int(0.7 * total_len)
            val_end = int(0.85 * total_len)
            
            if self.split == 'train':
                self.data_pairs = data[:train_end]
            elif self.split == 'validation':
                self.data_pairs = data[train_end:val_end]
            elif self.split == 'test':
                self.data_pairs = data[val_end:]
            else:
                raise ValueError("Invalid split name")             
        self.dataset_length = len(self.data_pairs)       

        # Initialize transformations for images
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

        html_tokens = ['<s_html>', '<table>', '<table style="border-collapse: collapse;">', '<th>'
                       , '<th style="border: 1px solid black;">', '<tr>', '<td>', '</td>'
                       , '<td style="border: 1px solid black;">', '</tr>', '</th>', '</table>, </s_html>']
        self.add_tokens(html_tokens)

        self.gt_token_sequences = []
        for sample in self.data_pairs:
            gt_jsons = sample["html"]
            self.gt_token_sequences.append(self.minify_html(gt_jsons))
       
        if task_start_token or prompt_end_token:
            # Assuming the tokenizer can handle adding tokens if necessary
            self.add_tokens([self.task_start_token, self.prompt_end_token])
            self.prompt_end_token_id = processor.tokenizer.convert_tokens_to_ids(self.prompt_end_token)

    def minify_html(self, html: str):
        # Replace escaped double quotes with regular double quotes
        html = html.replace('\\"', '"')
        # Remove newline characters
        html = html.replace('\n', '')
        # Optionally, remove extra spaces between tags if they exist
        html = ' '.join(html.split())
        return html

    def json2token(self, obj: Any, update_special_tokens_for_json_key: bool = True, sort_json_key: bool = True):
        """
        Convert an ordered JSON object into a token sequence
        """
        if type(obj) == dict:
            if len(obj) == 1 and "text_sequence" in obj:
                return obj["text_sequence"]
            else:
                output = ""
                if sort_json_key:
                    keys = sorted(obj.keys(), reverse=True)
                else:
                    keys = obj.keys()
                for k in keys:
                    if update_special_tokens_for_json_key:
                        self.add_tokens([fr"", fr""])
                    output += (
                        fr""
                        + self.json2token(obj[k], update_special_tokens_for_json_key, sort_json_key)
                        + fr""
                    )
                return output
        elif type(obj) == list:
            return r"".join(
                [self.json2token(item, update_special_tokens_for_json_key, sort_json_key) for item in obj]
            )
        else:
            obj = str(obj)
            if f"<{obj}/>" in added_tokens:
                obj = f"<{obj}/>"  # for categorical special tokens
            return obj
    
    def add_tokens(self, list_of_tokens: List[str]):
        """
        Add special tokens to tokenizer and resize the token embeddings of the decoder
        """
        newly_added_num = processor.tokenizer.add_tokens(list_of_tokens)
        if newly_added_num > 0:
            model.decoder.resize_token_embeddings(len(processor.tokenizer))
            added_tokens.extend(list_of_tokens)
    
    def __len__(self) -> int:
        return self.dataset_length

    def reassemble_html_tokens(self, tokens):
        # Implement reassembly logic that was discussed previously
        new_tokens = []
        buffer = ""
        for token in tokens:
            if token.startswith("▁") and buffer:
                new_tokens.append(buffer)
                buffer = token[1:]  # Remove the '▁' for a new token
            else:
                buffer += token.replace("▁", "")  # Remove '▁' and append to the current buffer
        if buffer:
            new_tokens.append(buffer)  # Append the last buffer if any
        return new_tokens

    def __getitem__(self, idx):
        item = self.data_pairs[idx]
        image_path = item['image'].replace("\\", "/")
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)

        target_sequence = random.choice(self.gt_token_sequences[idx])
        html_content = item['html']
        encoded_html = processor.tokenizer(
            html_content,
            return_tensors='pt',
            max_length=self.max_length,
            truncation=True,
            padding='max_length'
        )["input_ids"].squeeze(0)

        labels = encoded_html.clone()
        labels[labels == processor.tokenizer.pad_token_id] = self.ignore_id 
    
        return image, labels, target_sequence

In [26]:
processor.image_processor.size = image_size[::-1]
processor.image_processor.do_align_long_axis = False

tokenizer = processor.tokenizer

# Initialize the dataset

train_dataset = HtmlTablesDataset(
    json_file='./data_pairs.json', 
    max_length=2056,                         
    split="train", 
    task_start_token="", 
    prompt_end_token="",
    ignore_id=-100,
)

val_dataset = HtmlTablesDataset(
    json_file='./data_pairs.json',
    max_length=2056,                         
    split="validation", 
    task_start_token="", 
    prompt_end_token="",
    ignore_id=-100,
)


In [27]:
len(added_tokens)

12

In [28]:
tokenizer

XLMRobertaTokenizerFast(name_or_path='naver-clova-ix/donut-base', vocab_size=57522, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': '<mask>', 'additional_special_tokens': ['<s_iitcdip>', '<s_synthdog>']}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	3: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	57521: AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=True, special=Tru

In [29]:
print("Original number of tokens:", processor.tokenizer.vocab_size)
print("Number of tokens after adding special tokens:", len(processor.tokenizer))

Original number of tokens: 57522
Number of tokens after adding special tokens: 57537


In [30]:
pixel_values, labels, target_sequence = train_dataset[0]

In [31]:
print(pixel_values.shape)

torch.Size([3, 224, 224])


In [32]:
for id in labels.tolist()[:100]:
  if id != -100:
    print(processor.decode([id]))
  else:
    print(id)

<s>
<table style="border-collapse: collapse;">

<tr>

<th style="border: 1px solid black;">
Tas
k
</th>

<th style="border: 1px solid black;">
As
sign
ed
To
</th>

<th style="border: 1px solid black;">
Du
e
Date
</th>

</tr>

<tr>

<td style="border: 1px solid black;">
Design
Home
page
</td>

<td style="border: 1px solid black;">
Alice
</td>

<td style="border: 1px solid black;">
20
23
-11
-15
</td>

</tr>

<tr>

<td style="border: 1px solid black;">
Develop
Back
end
API
</td>

<td style="border: 1px solid black;">
Bob
</td>

<td style="border: 1px solid black;">
20
23
-11
-20
</td>

</tr>

<tr>

<td style="border: 1px solid black;">
Set
up
Data
base
</td>

<td style="border: 1px solid black;">
Charlie
</td>

<td style="border: 1px solid black;">
20
23
-11
-18
</td>

</tr>

<tr>

<td style="border: 1px solid black;">
Con
duct
User
Test
ing
</td>



In [40]:
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(['<s_html>'])[0]

In [42]:
print("Pad token ID:", processor.decode([model.config.pad_token_id]))
print("Decoder start token ID:", processor.decode([model.config.decoder_start_token_id]))

Pad token ID: <pad>
Decoder start token ID: <s_html>


In [16]:
from torch.utils.data import DataLoader
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=4)
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4)

In [None]:
batch = next(iter(train_dataloader))
pixel_values, labels, target_sequences = batch
print(pixel_values.shape)

In [43]:
print(len(val_dataset))

30


In [44]:
from pathlib import Path
import re
from nltk import edit_distance
import numpy as np
import math

from torch.nn.utils.rnn import pad_sequence
from torch.optim.lr_scheduler import LambdaLR

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only

In [45]:
class DonutModelPLModule(pl.LightningModule):
    def __init__(self, config, processor, model):
        super().__init__()
        self.config = config
        self.processor = processor
        self.model = model

    def training_step(self, batch, batch_idx):
        pixel_values, labels, _ = batch
        
        outputs = self.model(pixel_values, labels=labels)
        loss = outputs.loss
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx, dataset_idx=0):
        pixel_values, labels, answers = batch
        batch_size = pixel_values.shape[0]
        # we feed the prompt to the model
        decoder_input_ids = torch.full((batch_size, 1), self.model.config.decoder_start_token_id, device=self.device)
        
        outputs = self.model.generate(pixel_values,
                                   decoder_input_ids=decoder_input_ids,
                                   max_length=max_length,
                                   early_stopping=True,
                                   pad_token_id=self.processor.tokenizer.pad_token_id,
                                   eos_token_id=self.processor.tokenizer.eos_token_id,
                                   use_cache=True,
                                   num_beams=1,
                                   bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
                                   return_dict_in_generate=True,)
    
        predictions = []
        for seq in self.processor.tokenizer.batch_decode(outputs.sequences):
            seq = seq.replace(self.processor.tokenizer.eos_token, "").replace(self.processor.tokenizer.pad_token, "")
            seq = re.sub(r"<.*?>", "", seq, count=1).strip()  # remove first task start token
            predictions.append(seq)

        scores = []
        for pred, answer in zip(predictions, answers):
            pred = re.sub(r"(?:(?<=>) | (?=", "", answer, count=1)
            answer = answer.replace(self.processor.tokenizer.eos_token, "")
            scores.append(edit_distance(pred, answer) / max(len(pred), len(answer)))

            if self.config.get("verbose", False) and len(scores) == 1:
                print(f"Prediction: {pred}")
                print(f"    Answer: {answer}")
                print(f" Normed ED: {scores[0]}")

        self.log("val_edit_distance", np.mean(scores))
        
        return scores

    def configure_optimizers(self):
        # you could also add a learning rate scheduler if you want
        optimizer = torch.optim.Adam(self.parameters(), lr=self.config.get("lr"))
    
        return optimizer

    def train_dataloader(self):
        return train_dataloader

    def val_dataloader(self):
        return val_dataloader

In [46]:
config = {"max_epochs":30,
          "val_check_interval":0.2, # how many times we want to validate during an epoch
          "check_val_every_n_epoch":1,
          "gradient_clip_val":1.0,
          "num_training_samples_per_epoch": 800,
          "lr":3e-5,
          "train_batch_sizes": [8],
          "val_batch_sizes": [1],
          # "seed":2022,
          "num_nodes": 1,
          "warmup_steps": 300, # 800/8*30/10, 10%
          "result_path": "./result",
          "verbose": True,
          }

model_module = DonutModelPLModule(config, processor, model)

In [None]:
!huggingface-cli login