<a href="https://colab.research.google.com/github/yongsun-yoon/deep-learning-paper-implementation/blob/main/03-natural-language-process/MarkupLM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MarkupLM

## 0. Info

## paper
* title: MarkupLM: Pre-training of Text and Markup Language for Visually-rich Document Understanding
* author: Junlong Li et al.
* url: https://arxiv.org/abs/2110.08518

## feats
* dataset: KorQuAD 2.1 (Hugging Face `KETI-AIR/korquad`, config `v2.1`)

## refs
* https://huggingface.co/docs/transformers/model_doc/markuplm
* https://github.com/microsoft/unilm/tree/13b1fd1cb6828004e2cea81c9f93ababfe024922/markuplm

## 1. Setup

In [None]:
import re
import bs4
import html
import wandb
import easydict
import numpy as np
from lxml import etree
from bs4 import BeautifulSoup
from einops import rearrange
from typing import List
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

from datasets import load_dataset
from transformers import get_scheduler, BatchEncoding
from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForMaskedLM
from transformers.models.roberta.modeling_roberta import RobertaClassificationHead

In [None]:
# Global run configuration: target device and the length of the
# learning-rate schedule (total optimisation steps and warm-up steps).
cfg = easydict.EasyDict({
    'device': 'cuda:0',
    'num_training_steps': 10000,
    'num_warmup_steps': 500,
})

## 2. Utils

In [None]:
# Vocabulary mapping HTML/SVG tag names to integer ids (0-214, each used once).
# MarkupLMTokenizer uses this table as `tags_dict` when converting the steps of
# an XPath into tag ids; names missing from the table fall back to its
# `unk_tag_id`.
# NOTE: the ids of the 'font*' entries (84-89) do not follow key order.
# Presumably this mirrors the reference MarkupLM vocabulary -- do not renumber
# without confirming, since trained embeddings are indexed by these ids.
TAGS_DICT = {
    'a': 0,
    'abbr': 1,
    'acronym': 2,
    'address': 3,
    'altGlyph': 4,
    'altGlyphDef': 5,
    'altGlyphItem': 6,
    'animate': 7,
    'animateColor': 8,
    'animateMotion': 9,
    'animateTransform': 10,
    'applet': 11,
    'area': 12,
    'article': 13,
    'aside': 14,
    'audio': 15,
    'b': 16,
    'base': 17,
    'basefont': 18,
    'bdi': 19,
    'bdo': 20,
    'bgsound': 21,
    'big': 22,
    'blink': 23,
    'blockquote': 24,
    'body': 25,
    'br': 26,
    'button': 27,
    'canvas': 28,
    'caption': 29,
    'center': 30,
    'circle': 31,
    'cite': 32,
    'clipPath': 33,
    'code': 34,
    'col': 35,
    'colgroup': 36,
    'color-profile': 37,
    'content': 38,
    'cursor': 39,
    'data': 40,
    'datalist': 41,
    'dd': 42,
    'defs': 43,
    'del': 44,
    'desc': 45,
    'details': 46,
    'dfn': 47,
    'dialog': 48,
    'dir': 49,
    'div': 50,
    'dl': 51,
    'dt': 52,
    'ellipse': 53,
    'em': 54,
    'embed': 55,
    'feBlend': 56,
    'feColorMatrix': 57,
    'feComponentTransfer': 58,
    'feComposite': 59,
    'feConvolveMatrix': 60,
    'feDiffuseLighting': 61,
    'feDisplacementMap': 62,
    'feDistantLight': 63,
    'feFlood': 64,
    'feFuncA': 65,
    'feFuncB': 66,
    'feFuncG': 67,
    'feFuncR': 68,
    'feGaussianBlur': 69,
    'feImage': 70,
    'feMerge': 71,
    'feMergeNode': 72,
    'feMorphology': 73,
    'feOffset': 74,
    'fePointLight': 75,
    'feSpecularLighting': 76,
    'feSpotLight': 77,
    'feTile': 78,
    'feTurbulence': 79,
    'fieldset': 80,
    'figcaption': 81,
    'figure': 82,
    'filter': 83,
    'font': 89,
    'font-face': 88,
    'font-face-format': 84,
    'font-face-name': 85,
    'font-face-src': 86,
    'font-face-uri': 87,
    'footer': 90,
    'foreignObject': 91,
    'form': 92,
    'frame': 93,
    'frameset': 94,
    'g': 95,
    'glyph': 96,
    'glyphRef': 97,
    'h1': 98,
    'h2': 99,
    'h3': 100,
    'h4': 101,
    'h5': 102,
    'h6': 103,
    'head': 104,
    'header': 105,
    'hgroup': 106,
    'hkern': 107,
    'hr': 108,
    'html': 109,
    'i': 110,
    'iframe': 111,
    'image': 112,
    'img': 113,
    'input': 114,
    'ins': 115,
    'kbd': 116,
    'keygen': 117,
    'label': 118,
    'legend': 119,
    'li': 120,
    'line': 121,
    'linearGradient': 122,
    'link': 123,
    'main': 124,
    'map': 125,
    'mark': 126,
    'marker': 127,
    'marquee': 128,
    'mask': 129,
    'math': 130,
    'menu': 131,
    'menuitem': 132,
    'meta': 133,
    'metadata': 134,
    'meter': 135,
    'missing-glyph': 136,
    'mpath': 137,
    'nav': 138,
    'nobr': 139,
    'noembed': 140,
    'noframes': 141,
    'noscript': 142,
    'object': 143,
    'ol': 144,
    'optgroup': 145,
    'option': 146,
    'output': 147,
    'p': 148,
    'param': 149,
    'path': 150,
    'pattern': 151,
    'picture': 152,
    'plaintext': 153,
    'polygon': 154,
    'polyline': 155,
    'portal': 156,
    'pre': 157,
    'progress': 158,
    'q': 159,
    'radialGradient': 160,
    'rb': 161,
    'rect': 162,
    'rp': 163,
    'rt': 164,
    'rtc': 165,
    'ruby': 166,
    's': 167,
    'samp': 168,
    'script': 169,
    'section': 170,
    'select': 171,
    'set': 172,
    'shadow': 173,
    'slot': 174,
    'small': 175,
    'source': 176,
    'spacer': 177,
    'span': 178,
    'stop': 179,
    'strike': 180,
    'strong': 181,
    'style': 182,
    'sub': 183,
    'summary': 184,
    'sup': 185,
    'svg': 186,
    'switch': 187,
    'symbol': 188,
    'table': 189,
    'tbody': 190,
    'td': 191,
    'template': 192,
    'text': 193,
    'textPath': 194,
    'textarea': 195,
    'tfoot': 196,
    'th': 197,
    'thead': 198,
    'time': 199,
    'title': 200,
    'tr': 201,
    'track': 202,
    'tref': 203,
    'tspan': 204,
    'tt': 205,
    'u': 206,
    'ul': 207,
    'use': 208,
    'var': 209,
    'video': 210,
    'view': 211,
    'vkern': 212,
    'wbr': 213,
    'xmp': 214
}

In [None]:
def xpath_soup(element):
    """Describe the path from the document root down to ``element``.

    If ``element`` has no ``name`` (e.g. a text node), its parent is used as
    the starting node instead.

    Returns:
        A pair ``(xpath_tags, xpath_subscripts)`` ordered root-first.
        ``xpath_tags[i]`` is the tag name at depth ``i``;
        ``xpath_subscripts[i]`` is 0 when that tag is the only direct child
        of its parent with that name, otherwise its 1-based position among
        the same-named direct children.
    """
    node = element if element.name else element.parent

    # Collected leaf-first while climbing, flipped to root-first at the end.
    steps = []
    for ancestor in node.parents:
        same_named = ancestor.find_all(node.name, recursive=False)
        if len(same_named) == 1:
            position = 0
        else:
            # Identity comparison: equal-looking siblings must not be confused.
            position = next(
                rank for rank, candidate in enumerate(same_named, 1) if candidate is node
            )
        steps.append((node.name, position))
        node = ancestor

    steps.reverse()
    xpath_tags = [name for name, _ in steps]
    xpath_subscripts = [position for _, position in steps]
    return xpath_tags, xpath_subscripts


def get_three_from_single(html_string):
    """Parse an HTML string and collect its text nodes with their XPath parts.

    Returns:
        Three parallel lists: the stripped, HTML-unescaped text of each
        non-empty text node, the tag-name sequence of its XPath, and the
        subscript sequence of its XPath (both as produced by ``xpath_soup``).

    Raises:
        ValueError: if the three lists end up with different lengths.
    """
    soup = BeautifulSoup(html_string, "html.parser")

    all_doc_strings, string2xtag_seq, string2xsubs_seq = [], [], []

    for node in soup.descendants:
        # Exact type comparison on purpose: subclasses of NavigableString
        # do not match and are skipped.
        if type(node) is not bs4.element.NavigableString:
            continue
        # Only text whose direct parent is a plain Tag is kept.
        if type(node.parent) is not bs4.element.Tag:
            continue

        text = html.unescape(node).strip()
        if not text:
            continue

        tags, subscripts = xpath_soup(node)
        all_doc_strings.append(text)
        string2xtag_seq.append(tags)
        string2xsubs_seq.append(subscripts)

    n_strings = len(all_doc_strings)
    if n_strings != len(string2xtag_seq):
        raise ValueError("Number of doc strings and xtags does not correspond")
    if n_strings != len(string2xsubs_seq):
        raise ValueError("Number of doc strings and xsubs does not correspond")

    return all_doc_strings, string2xtag_seq, string2xsubs_seq


def construct_xpath(xpath_tags, xpath_subscripts):
    """Render parallel tag/subscript sequences as an XPath string.

    Each step becomes ``/tag``; a non-zero subscript is appended as
    ``[n]``. Empty inputs give the empty string.
    """
    steps = (
        f"/{tag}[{subscript}]" if subscript != 0 else f"/{tag}"
        for tag, subscript in zip(xpath_tags, xpath_subscripts)
    )
    return "".join(steps)


def extract_features(html_string):
    """Turn an HTML string into text nodes and their XPath strings.

    Returns:
        ``(nodes, xpaths)``: the text of every non-empty text node found by
        ``get_three_from_single`` and, in the same order, the XPath string
        built for it by ``construct_xpath``.
    """
    nodes, tag_seqs, subscript_seqs = get_three_from_single(html_string)
    xpaths = [
        construct_xpath(tags, subscripts)
        for _, tags, subscripts in zip(nodes, tag_seqs, subscript_seqs)
    ]
    return nodes, xpaths


class MarkupLMTokenizer(object):
    """Wrap a base tokenizer so every token also carries its node's XPath.

    Each XPath is converted into two fixed-length (``max_depth``) integer
    sequences: one of tag ids (looked up in ``tags_dict``) and one of
    subscripts. Tokens that do not belong to any node (special / padding
    tokens) receive all-padding sequences.
    """

    tags_dict = TAGS_DICT
    # Id used for tag names that are missing from ``tags_dict``.
    unk_tag_id = len(tags_dict)
    # Id used to pad the tag sequence up to ``max_depth``.
    pad_tag_id = unk_tag_id + 1
    # XPaths deeper than this are truncated.
    max_depth = 50
    # Subscripts larger than this are clipped to it.
    max_width = 1000
    # Value used to pad the subscript sequence up to ``max_depth``.
    pad_width = 1001

    def __init__(self, base_tokenizer):
        """Store the wrapped tokenizer and precompute the all-padding XPath."""
        self.base_tokenizer = base_tokenizer
        # An empty XPath has no steps, so both sequences are pure padding.
        self.pad_tags, self.pad_subs = self.get_xpath_seq('')

    def get_xpath_seq(self, xpath):
        """
        Given the xpath expression of one particular node (like "/html/body/div/li[1]/div/span[2]"), return a list of
        tag IDs and corresponding subscripts, taking into account max depth.

        Both returned lists have exactly ``max_depth`` entries: steps beyond
        ``max_depth`` are dropped, shorter paths are padded with
        ``pad_tag_id`` / ``pad_width``.
        """
        xpath_tags_list = []
        xpath_subs_list = []

        for unit in xpath.split("/"):
            unit = unit.strip()
            if not unit:
                continue
            # "li[3]" -> ["li", "3]"]; a step without a subscript counts as 0.
            name_subs = unit.split("[")
            tag_name = name_subs[0]
            sub = 0 if len(name_subs) == 1 else int(name_subs[1][:-1])
            xpath_tags_list.append(self.tags_dict.get(tag_name, self.unk_tag_id))
            xpath_subs_list.append(min(self.max_width, sub))

        xpath_tags_list = xpath_tags_list[: self.max_depth]
        # Truncate the subscripts themselves; slicing the tag list here would
        # silently replace every subscript with a tag id.
        xpath_subs_list = xpath_subs_list[: self.max_depth]
        xpath_tags_list += [self.pad_tag_id] * (self.max_depth - len(xpath_tags_list))
        xpath_subs_list += [self.pad_width] * (self.max_depth - len(xpath_subs_list))
        return xpath_tags_list, xpath_subs_list

    def encode_nodes(self, nodes: List[str], xpaths: List[str], truncation=True, max_length=512, padding='max_length'):
        """Tokenize one document's text nodes and align XPath features to tokens.

        Args:
            nodes: text of each node; passed to the base tokenizer as
                pre-split words.
            xpaths: XPath string of each node, parallel to ``nodes``.
            truncation, max_length, padding: forwarded to the base tokenizer.

        Returns:
            A dict with ``input_ids`` and ``attention_mask`` from the base
            tokenizer plus ``xpath_tags_seq`` / ``xpath_subs_seq``: one
            ``max_depth``-long list per token.
        """
        # NOTE(review): token_to_word() is presumably only available on
        # "fast" tokenizers -- confirm for the tokenizer being wrapped.
        encodings = self.base_tokenizer(nodes, is_split_into_words=True, add_special_tokens=True, truncation=truncation, max_length=max_length, padding=padding)
        input_ids = encodings['input_ids']
        attention_mask = encodings['attention_mask']

        # Every token of a node shares that node's XPath: convert it once.
        cache = {}
        xpath_tags_seq = []
        xpath_subs_seq = []
        for ti in range(len(input_ids)):
            wi = encodings.token_to_word(ti)
            if wi is None:
                # Token does not map to any node.
                tags, subs = self.pad_tags, self.pad_subs
            else:
                if wi not in cache:
                    cache[wi] = self.get_xpath_seq(xpaths[wi])
                tags, subs = cache[wi]

            xpath_tags_seq.append(tags)
            xpath_subs_seq.append(subs)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'xpath_tags_seq': xpath_tags_seq,
            'xpath_subs_seq': xpath_subs_seq
        }

    def batch_encode_nodes(self, batch_nodes: List[List[str]], batch_xpaths: List[List[str]], max_length=512, padding='max_length'):
        """Encode several documents and stack the results into LongTensors.

        Every document is encoded with ``encode_nodes``; the per-key lists are
        then converted to ``torch.LongTensor`` and wrapped in a
        ``BatchEncoding``. The conversion requires all documents to have the
        same encoded length.
        """
        batch = [
            self.encode_nodes(nodes, xpaths, max_length=max_length, padding=padding) for nodes, xpaths in zip(batch_nodes, batch_xpaths)
        ]
        keys = batch[0].keys()
        batch = {k: torch.LongTensor([b[k] for b in batch]) for k in keys}
        return BatchEncoding(batch)

## 3. Pretrain

### 3.1. Data

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Pre-training dataset producing masked-LM and title-matching examples.

    Each item is built from the ``context`` column (an HTML string) of one
    row of ``data``:

    * the HTML is cropped to at most ``max_context_length`` characters,
    * with probability 0.5 its ``<title>...</title>`` is swapped for the
      title of another row (``title_labels`` 0) or kept (``title_labels`` 1);
      items without a title get ``title_labels`` -100,
    * non-special tokens are replaced by the mask token with probability
      0.15; ``mlm_labels`` holds the original id at masked positions and
      -100 everywhere else.
    """

    # Maximum number of characters of HTML kept per example.
    max_context_length = 10000
    # Raw string: '\s' in a normal string literal is an invalid escape.
    _title_re = re.compile(r'<title>[\s\S]+</title>')

    def __init__(self, data, tokenizer):
        """
        Args:
            data: table accessed through ``.iloc[idx]['context']`` and ``len()``.
            tokenizer: object providing ``encode_nodes`` and ``base_tokenizer``.
        """
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def find_title(self, context):
        """Return the ``<title>...</title>`` span of ``context`` or ``''``."""
        title_span = self._title_re.search(context)
        if title_span is not None:
            return title_span.group(0)
        return ''

    def __getitem__(self, idx):
        """Build one example; see the class docstring for the returned keys."""
        context = self.data.iloc[idx]['context']
        overflow = len(context) - self.max_context_length
        if overflow > 0:
            # randint's upper bound is exclusive; +1 lets the final window be drawn.
            context_start = np.random.randint(0, overflow + 1)
            context = context[context_start:context_start+self.max_context_length]

        title = self.find_title(context)
        if not title:
            # No title in the (possibly cropped) HTML: excluded from the title loss.
            title_labels = -100

        elif len(self) < 2 or np.random.rand() < 0.5:
            # Keep the real title. A single-row dataset has no other row to
            # borrow a title from, so it always takes this branch instead of
            # looping forever while searching for a different index.
            title_labels = 1

        else:
            other_idx = idx
            while other_idx == idx:
                other_idx = np.random.randint(len(self))
            other_context = self.data.iloc[other_idx]['context']
            # NOTE(review): other_title is '' when the other row has no title,
            # in which case the title is simply removed -- confirm this is intended.
            other_title = self.find_title(other_context)
            context = context.replace(title, other_title)
            title_labels = 0

        nodes, xpaths = extract_features(context)
        encoding = self.tokenizer.encode_nodes(nodes, xpaths)

        # Use the tokenizer owned by this dataset rather than a notebook global.
        base_tokenizer = self.tokenizer.base_tokenizer
        mask_token_id = base_tokenizer.mask_token_id
        original_ids = encoding['input_ids']
        special_tokens_mask = base_tokenizer.get_special_tokens_mask(original_ids, already_has_special_tokens=True)

        input_ids = []
        mlm_labels = []
        for token_id, is_special in zip(original_ids, special_tokens_mask):
            # Special tokens are never masked (and consume no random draw).
            if not is_special and np.random.rand() <= 0.15:
                input_ids.append(mask_token_id)
                mlm_labels.append(token_id)
            else:
                input_ids.append(token_id)
                mlm_labels.append(-100)

        encoding['input_ids'] = input_ids
        encoding['title_labels'] = title_labels
        encoding['mlm_labels'] = mlm_labels
        return encoding
    
    
def collate_fn(batch):
    """Stack a list of example dicts into a single ``BatchEncoding``.

    The keys are taken from the first example; for each key the values of
    all examples are converted to tensors and stacked along a new leading
    dimension.
    """
    field_names = batch[0].keys()
    stacked = {
        name: torch.stack([torch.tensor(example[name]) for example in batch])
        for name in field_names
    }
    return BatchEncoding(stacked)

### 3.2. Model

In [None]:
class XPathEmbeddings(nn.Module):
    """Embed per-token XPath tag/subscript sequences into the hidden space.

    Every depth position has its own tag embedding table and its own
    subscript embedding table. The per-depth vectors are concatenated,
    tag and subscript parts are summed, and the result goes through a
    two-layer feed-forward projection (Linear -> ReLU -> Dropout -> Linear).
    """

    def __init__(
        self,
        max_depth=50,
        xpath_unit_hidden_size=32,
        hidden_size=768,
        hidden_dropout_prob = 0.1,
        max_xpath_tag_unit_embeddings = 256,
        max_xpath_subs_unit_embeddings = 1024,
    ):
        super().__init__()
        self.max_depth = max_depth

        # Width of the concatenated per-depth embeddings.
        flat_size = xpath_unit_hidden_size * max_depth

        # Submodules are created in a fixed order so that parameter names,
        # state_dict ordering and random initialisation stay stable.
        # NOTE: xpath_unitseq2_embeddings is not used by forward().
        self.xpath_unitseq2_embeddings = nn.Linear(flat_size, hidden_size)
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.activation = nn.ReLU()
        self.xpath_unitseq2_inner = nn.Linear(flat_size, 4 * hidden_size)
        self.inner2emb = nn.Linear(4 * hidden_size, hidden_size)

        tag_tables = [
            nn.Embedding(max_xpath_tag_unit_embeddings, xpath_unit_hidden_size)
            for _ in range(max_depth)
        ]
        self.xpath_tag_sub_embeddings = nn.ModuleList(tag_tables)

        subs_tables = [
            nn.Embedding(max_xpath_subs_unit_embeddings, xpath_unit_hidden_size)
            for _ in range(max_depth)
        ]
        self.xpath_subs_sub_embeddings = nn.ModuleList(subs_tables)

    def forward(self, xpath_tags_seq=None, xpath_subs_seq=None):
        """Return XPath embeddings whose last dimension is ``hidden_size``.

        Both inputs are indexed as ``[:, :, depth]`` for every depth in
        ``range(max_depth)``, so they must be integer tensors with at least
        ``max_depth`` entries along their third dimension.
        """
        depths = range(self.max_depth)
        tag_parts = [
            self.xpath_tag_sub_embeddings[depth](xpath_tags_seq[:, :, depth])
            for depth in depths
        ]
        subs_parts = [
            self.xpath_subs_sub_embeddings[depth](xpath_subs_seq[:, :, depth])
            for depth in depths
        ]

        # Concatenate along the feature axis, then add tag and subscript parts.
        combined = torch.cat(tag_parts, dim=-1) + torch.cat(subs_parts, dim=-1)

        hidden = self.xpath_unitseq2_inner(combined)
        hidden = self.activation(hidden)
        hidden = self.dropout(hidden)
        return self.inner2emb(hidden)

In [None]:
class Model(nn.Module):
    """Masked-LM backbone extended with XPath embeddings and a title head.

    The XPath embedding of every token is added to its word embedding before
    the sum is passed to the backbone as ``inputs_embeds``.
    """

    def __init__(self, backbone_name):
        """
        Args:
            backbone_name: checkpoint name passed to ``AutoConfig`` and
                ``AutoModelForMaskedLM``. ``forward`` reads
                ``backbone.roberta``, so the checkpoint must load as a
                RoBERTa masked-LM model.
        """
        super().__init__()
        self.config = AutoConfig.from_pretrained(backbone_name)
        self.backbone = AutoModelForMaskedLM.from_pretrained(backbone_name)
        # The XPath embeddings are summed with the word embeddings, so their
        # width must follow the backbone's hidden size.
        self.xpath_embedding = XPathEmbeddings(hidden_size=self.config.hidden_size)
        self.title_head = RobertaClassificationHead(self.config)

    def forward(self, input_ids, attention_mask, xpath_tags_seq, xpath_subs_seq):
        """Return ``(mlm_outputs, title_outputs)``.

        ``mlm_outputs`` are the backbone's masked-LM logits; ``title_outputs``
        is the output of ``title_head`` applied to the last hidden state.
        """
        xpath_embeds = self.xpath_embedding(xpath_tags_seq, xpath_subs_seq)
        token_embeds = self.backbone.roberta.embeddings.word_embeddings(input_ids)
        embeds = token_embeds + xpath_embeds
        outputs = self.backbone(inputs_embeds=embeds, attention_mask=attention_mask, output_hidden_states=True)

        mlm_outputs = outputs.logits
        last_hidden_state = outputs.hidden_states[-1]
        # Use this module's own head, not whatever the global `model` points to.
        title_outputs = self.title_head(last_hidden_state)
        return mlm_outputs, title_outputs

### 3.3. Pretrain

In [None]:
# Load the 'v2.1' config of KETI-AIR/korquad and keep one row per distinct
# HTML context from the train split.
data = load_dataset('KETI-AIR/korquad', 'v2.1')
train_data = (
    data['train']
    .to_pandas()
    .drop_duplicates(subset='context')
)

In [None]:
# Load the pretrained 'klue/roberta-base' tokenizer and wrap it so that text
# nodes can be encoded together with their XPath tag/subscript sequences.
base_tokenizer = AutoTokenizer.from_pretrained('klue/roberta-base')
tokenizer = MarkupLMTokenizer(base_tokenizer)

In [None]:
# Pre-training examples are built on the fly by Dataset.__getitem__ and
# batched (8 per step, reshuffled every pass) through collate_fn.
dataset = Dataset(data=train_data, tokenizer=tokenizer)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=collate_fn,
)

In [None]:
# Build the model on the 'klue/roberta-base' checkpoint, switch it to training
# mode and move it to the configured device.
model = Model('klue/roberta-base')
# The result is bound to `_`, presumably so the notebook cell does not echo
# the module's repr.
_ = model.train().to(cfg.device)

In [None]:
# AdamW over all model parameters, with a cosine learning-rate schedule whose
# warm-up and total length come from cfg.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-4)
scheduler = get_scheduler(
    name='cosine',
    optimizer=optimizer,
    num_warmup_steps=cfg.num_warmup_steps,
    num_training_steps=cfg.num_training_steps,
)

In [None]:
def _masked_cross_entropy(logits, labels, ignore_index=-100):
    """Mean cross-entropy over labels that are not ``ignore_index``.

    Identical to ``F.cross_entropy(..., reduction='mean')`` whenever at least
    one label is valid. When every label is ignored it returns 0 instead of
    NaN, so a batch without usable targets cannot corrupt the weights.
    """
    n_valid = (labels != ignore_index).sum().clamp(min=1)
    summed = F.cross_entropy(logits, labels, ignore_index=ignore_index, reduction='sum')
    return summed / n_valid


dataiter = iter(dataloader)

# Training is step-based: the dataloader is restarted whenever it runs out.
pbar = tqdm(range(1, cfg.num_training_steps+1))
for st in pbar:
    try:
        batch = next(dataiter)
    except StopIteration:
        # One pass over the data is finished: begin the next one.
        dataiter = iter(dataloader)
        batch = next(dataiter)

    batch = batch.to(cfg.device)

    mlm_outputs, title_outputs = model(batch.input_ids, batch.attention_mask, batch.xpath_tags_seq, batch.xpath_subs_seq)
    # Labels equal to -100 (unmasked tokens, examples without a title) are ignored.
    mlm_loss = _masked_cross_entropy(rearrange(mlm_outputs, 'B S V -> (B S) V'), rearrange(batch.mlm_labels, 'B S -> (B S)'))
    title_loss = _masked_cross_entropy(title_outputs, batch.title_labels)
    loss = mlm_loss + title_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()

    log = {'loss': loss.item(), 'mlm': mlm_loss.item(), 'title': title_loss.item()}
    pbar.set_postfix(log)

    # Overwrite the checkpoint every 1000 steps.
    if st % 1000 == 0:
        torch.save(model.state_dict(), 'markuplm.pt')