## Import Modules

In [1]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)

import os 
import glob
from transformers import MarkupLMProcessor
from transformers import MarkupLMForTokenClassification

import torch
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

import utils
# import input_pipeline as ip

import pandas as pd

from tqdm.auto import tqdm
import yaml

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
torch.backends.mps.is_available()

True

In [3]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
device = 'cpu'
print(f"Using device {device}")

mps_device = torch.device("mps")

Using device cpu


## Define Helper Functions

1. create_raw_dataset: takes the tagged csvs and creates a dict of 
    xpaths, nodes, node_labels

In [4]:
def create_raw_dataset(tagged_csv_path, id2label, label2id, is_train=True):
    """Preprocesses the tagged csvs in the format required by MarkupLM"""

    tagged_df = pd.read_csv(tagged_csv_path)

    # in train mode we expect text and xpaths that are highlighted 
    # by an annotator
    if is_train:
        col_list = ["nodes", "xpaths", "node_labels"]
        
        tagged_df["highlighted_xpaths"] = tagged_df["highlighted_xpaths"].fillna(
            tagged_df["xpaths"]
        )
        tagged_df["highlighted_segmented_text"] = tagged_df[
            "highlighted_segmented_text"
        ].fillna(tagged_df["text"])

        # drop non-ASCII chars
        tagged_df["highlighted_segmented_text"] = (
            tagged_df["highlighted_segmented_text"]
            .str.encode("ascii", errors="ignore")
            .str.decode("ascii")
        )

        # rename columns to match MarkupLM convention
        tagged_df = tagged_df.rename(
            columns={
                "highlighted_xpaths": "xpaths",
                "highlighted_segmented_text": "nodes",
                "tagged_sequence": "node_labels",
            },
        )

        # convert node labels to integer values
        tagged_df["node_labels"] = tagged_df["node_labels"].apply(
            lambda label: label2id[label]
        )
    
    else:
        col_list = ["nodes", "xpaths"]
        
        # rename columns to match MarkupLM convention
        tagged_df = tagged_df.rename(
            columns={
                "xpaths": "xpaths",
                "text": "nodes",
            },
        )     
    
    tagged_output = tagged_df.loc[:, col_list].to_dict(orient="list")

    # convert each key to a list of lists just like the MarkupLM
    # pipeline requires
    for k, v in tagged_output.items():
        tagged_output[k] = [v]

    return tagged_output

In [5]:
class MarkupLMDataset(Dataset):
    """Dataset for token classification with MarkupLM."""

    def __init__(self, data, processor=None, max_length=512, is_train=True):
        self.data = data
        self.is_train = is_train
        self.processor = processor
        self.max_length = max_length
        self.encodings = []
        self.get_encoding_windows()
        

    def get_encoding_windows(self):
        """Splits the tokenized input into windows of 512 tokens"""
                
        for item in self.data:            
            if self.is_train:
                nodes, xpaths, node_labels = (
                    item["nodes"],
                    item["xpaths"],
                    item['node_labels']
                )
            else:
                nodes, xpaths, node_labels = (
                    item["nodes"],
                    item["xpaths"],
                    None
                )                
            
            # provide encoding to processor
            encoding = self.processor(
                nodes=nodes,
                xpaths=xpaths,
                node_labels=node_labels,
                padding="max_length",
                max_length=self.max_length,
                return_tensors="pt",
                truncation=False,
                return_offsets_mapping=True
            )

            # remove batch dimension
            encoding = {k: v.squeeze() for k, v in encoding.items()}

            # chunk up the encoding sequences to that it is less than the 
            # max input length of 512 tokens
            if not self.is_train:
                #num_tokens = len(item['nodes'][0])
                num_tokens = len(encoding['input_ids'])
                
                for idx in range(0, num_tokens, self.max_length):
                    batch_encoding = {}
                    for k, v in encoding.items():
                        batch_encoding[k] = v[idx: idx + self.max_length]

                    self.encodings.append(batch_encoding)                    
                    continue
            
            else:
                if len(encoding["input_ids"]) <= self.max_length:                    
                    self.encodings.append(encoding)
                    continue

                else:
                    batch_encoding = {}

                    start_idx, end_idx = 0, self.max_length

                    while end_idx < len(encoding["input_ids"]):
                        # decrement the end_idx by 1 until the label is not -100
                        while encoding["labels"][end_idx] == -100:
                            end_idx = end_idx - 1

                            # if the end idx is equal to the start idx meaning
                            # we don't encounter a non -100 token,
                            # we set window size as the max_length
                            if end_idx == start_idx:
                                end_idx = start_idx + self.max_length
                                break

                        for k, v in encoding.items():
                            batch_encoding[k] = v[start_idx:end_idx]

                        self.encodings.append(batch_encoding)
                        batch_encoding = {}

                        # update the pointers
                        start_idx = end_idx
                        end_idx = end_idx + self.max_length

                    # collect the remaining tokens
                    for k, v in encoding.items():
                        batch_encoding[k] = v[start_idx:]

                    if batch_encoding:
                        self.encodings.append(batch_encoding)

    def __len__(self):
        return len(self.encodings)

    def __getitem__(self, idx):
        # first, get nodes, xpaths and node labels
        item = self.encodings[idx]

        # pad the encodings to max_length of 512 tokens
        padded_item = self.processor.tokenizer.pad(
            item, max_length=self.max_length, padding="max_length", return_tensors="pt"
        )

        return padded_item

## Define Inference Loop and Main Function Execution

In [6]:
logits = {1:[]}
def run_inference_loop(dataloader, model, device, config, processor):
    '''Runs eval loop for entire dataset

    Args:
        dataloader: torch.utils.data.DataLoader: iterator over Dataset object
        model: transformers.PreTrainedModel. fine-tuned MarkupLM model
        device: torch.device. Specifies whether GPU is available for computation
        label_list: list. List of labels used to train the MarkupLM model
        config: dict. Contains user-provided params and args

    Returns:
        None
    '''
    model.eval()
    
    results = {"nodes": [], "preds": []}
    for batch in tqdm(dataloader, desc='inference_loop'):
        # get the inputs;
        inputs = {k: v.to(device) for k, v in batch.items()}

        # if ablation mode is set to true then
        # either mask the xpaths or shuffle them
        if config["ablation"]["run_ablation"]:
            inputs = utils.ablation(config, inputs)

        # get the offset mapping. It contains the spans of the 
        # words that were split during tokenization. 
        # Info present at a token level
        offset_mapping = inputs.pop('offset_mapping').squeeze().tolist()

        # forward + backward + optimize
        outputs = model(**inputs)

        predictions = outputs.logits.argmax(dim=-1)
        #logits[1] = outputs.logits
        #print(predictions.squeeze().tolist())
        pred_labels = [model.config.id2label[idx] for idx in predictions.squeeze().tolist()]
        
        input_ids = inputs['input_ids'].detach().numpy().flatten().tolist()
        input_word_pieces = [processor.decode([id]) for id in input_ids]
        
        
        # input_ids = [x for x in input_ids if x not in special_tokens]
        # print(input_ids)
        results['nodes'].append(input_word_pieces)
        results['preds'].append(pred_labels)
                                
    return results

In [7]:
def main(config, test_data, model_ckpt_path=None, is_train=False):
    '''Main execution of script'''
    # get the  list of labels along with the label to id mapping and
    # reverse mapping
    label_list, id2label, label2id = utils.get_label_list(config)
    
    # define the processor and model
    if config["model"]["use_large_model"]:
        processor = MarkupLMProcessor.from_pretrained(
            "microsoft/markuplm-large",
            only_label_first_subword=config['model']['label_only_first_subword']
        )
        model = MarkupLMForTokenClassification.from_pretrained(
            "microsoft/markuplm-large", id2label=id2label, label2id=label2id
        )

    else:
        processor = MarkupLMProcessor.from_pretrained(
            "microsoft/markuplm-base",
            only_label_first_subword=config['model']['label_only_first_subword'],
        )
        model = MarkupLMForTokenClassification.from_pretrained(
            "microsoft/markuplm-base", id2label=id2label, label2id=label2id
        )
        
    if model_ckpt_path is not None:
        model_ckpt = torch.load(model_ckpt_path, 
                                map_location='cpu')
        
        model.load_state_dict(model_ckpt)
        

    processor.parse_html = False
    
    # convert the input dataset
    # to torch datasets. Create the dataloaders as well
    test_dataset = MarkupLMDataset(
        data=test_data,
        processor=processor,
        max_length=config["model"]["max_length"],
        is_train=is_train
    )
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False
    )

    model.to(device)  # move to GPU if available

    print("*" * 50)
    print(f'Running Inference Loop!')
    print("*" * 50)

    # run inference loop
    results = run_inference_loop(test_dataloader, model, device, 
                                     config, processor)


    return results

# Inference Execution

## Specify the input params

In [8]:
config_path = './configs/config.yaml'

test_contract_dir = "../contracts/test"

# if loading from ckpt then change the line below
model_ckpt_path = "/Users/pranabislam/Desktop/research/MarkupMnA-Markup-Based-Segmentation-of-MnA-Agreements/markup-mna/models/pretrained_models/markuplm_base_model_ablation_shuffle_num_contract_100pct_f1-0.871.pt"

max_length = 512
test_batch_size = 1

In [9]:
# read the config file 
with open(config_path, 'r') as yml:
    config = yaml.safe_load(yml)

In [10]:
label_list, id2label, label2id = utils.get_label_list(config)
num_labels = len(label2id)

In [43]:
test_contracts = glob.glob(os.path.join(test_contract_dir, "*.csv"))

test_contracts = [test_contracts[5]]

print(f"Found {len(test_contracts)} in test dir")

Found 1 in test dir


In [44]:
test_contracts

['../contracts/test/contract_92.csv']

In [45]:
test_data = [] 
for tagged_path in test_contracts:
    tagged_output = create_raw_dataset(tagged_path, 
                                       id2label=id2label, 
                                       label2id=label2id,
                                       is_train=False)

    test_data.append(tagged_output)

### Run the inference pipeline

In [46]:
results = main(config, test_data, model_ckpt_path=model_ckpt_path,is_train=False)

Some weights of the model checkpoint at microsoft/markuplm-base were not used when initializing MarkupLMForTokenClassification: ['nrp_cls.decoder.bias', 'markuplm.pooler.dense.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'nrp_cls.dense.bias', 'nrp_cls.dense.weight', 'nrp_cls.LayerNorm.bias', 'ptc_cls.weight', 'cls.predictions.bias', 'ptc_cls.bias', 'cls.predictions.transform.LayerNorm.bias', 'markuplm.pooler.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias', 'nrp_cls.LayerNorm.weight', 'nrp_cls.decoder.weight']
- This IS expected if you are initializing MarkupLMForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing MarkupLMForTokenClassification from the checkpoint of a model that 

**************************************************
Running Inference Loop!
**************************************************


inference_loop:   0%|                                   | 0/159 [00:00<?, ?it/s]You're using a MarkupLMTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
inference_loop: 100%|█████████████████████████| 159/159 [01:40<00:00,  1.58it/s]


In [47]:
df = pd.DataFrame.from_dict(results)

In [48]:
df.head()

Unnamed: 0,nodes,preds
0,"[<s>, EX, -, 2, ., 1, 2, d, 36, 14, 18, d, ex,...","[o, o, o, s_ssn, o, o, o, o, s_n, o, o, o, o, ..."
1,"[Material, Contracts, 39, Section, , 3, ., 2...","[o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, ..."
2,"[7, ., 10, Take, over, Stat, utes, 68, Sectio...","[e_ssn, o, s_n, o, o, o, o, o, o, o, o, o, o, ..."
3,"[ Company, ,, with, the, Company, survivin...","[o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, ..."
4,"[ of, the, applicable, Support, Agreement,...","[o, o, o, o, o, o, o, o, o, o, o, o, o, o, o, ..."


In [49]:
df = df.explode(['nodes', 'preds']).reset_index(drop=True)

In [50]:
df.head()

Unnamed: 0,nodes,preds
0,<s>,o
1,EX,o
2,-,o
3,2,s_ssn
4,.,o


In [51]:
df

Unnamed: 0,nodes,preds
0,<s>,o
1,EX,o
2,-,o
3,2,s_ssn
4,.,o
...,...,...
81403,<pad>,o
81404,<pad>,o
81405,<pad>,o
81406,<pad>,o


## Let's try to match the predictions with the visualizer now (need some way to glue back to xpaths)

In [52]:
#inference_csv_path = '/Users/pranabislam/Desktop/research/MarkupMnA-Markup-Based-Segmentation-of-MnA-Agreements/contracts/test/contract_81.csv'
inference_csv_path = test_contracts[0]
inference_csv = pd.read_csv(inference_csv_path)[['xpaths','text']].copy()
inference_csv

Unnamed: 0,xpaths,text
0,/html/body/document/type,EX-2.1
1,/html/body/document/type/sequence,2
2,/html/body/document/type/sequence/filename,d361418dex21.htm
3,/html/body/document/type/sequence/filename/des...,EX-2.1
4,/html/body/document/type/sequence/filename/des...,EX-2.1
...,...,...
3469,/html/body/document/type/sequence/filename/des...,[ ]
3470,/html/body/document/type/sequence/filename/des...,By:
3471,/html/body/document/type/sequence/filename/des...,Name:
3472,/html/body/document/type/sequence/filename/des...,"Secretary, [ ]"


In [53]:
def create_repeat_xpaths(processor, row):
    '''
    Given a row in a dataset, create a list of repeating xpaths with length of the number of tokens
    in the text row
    '''
    enc = processor(
        nodes=[row.text],
        xpaths=[row.xpaths],
        #node_labels=node_labels,
        #padding="max_length",
        max_length=512,
        return_tensors="pt",
        truncation=False,
    )
    
    num_tokens = enc.input_ids.shape[1] - 2
    tokens = [processor.decode(x, skip_special_tokens=False) for x in enc.input_ids[0][1:-1]]
    
    return [[row.xpaths] * num_tokens, tokens]

In [54]:
processor = MarkupLMProcessor.from_pretrained(
    "microsoft/markuplm-base",
    only_label_first_subword=True
)
processor.parse_html = False

In [55]:
g = inference_csv.apply(lambda row: create_repeat_xpaths(processor, row), axis=1)

In [56]:
g

0       [[/html/body/document/type, /html/body/documen...
1              [[/html/body/document/type/sequence], [2]]
2       [[/html/body/document/type/sequence/filename, ...
3       [[/html/body/document/type/sequence/filename/d...
4       [[/html/body/document/type/sequence/filename/d...
                              ...                        
3469    [[/html/body/document/type/sequence/filename/d...
3470    [[/html/body/document/type/sequence/filename/d...
3471    [[/html/body/document/type/sequence/filename/d...
3472    [[/html/body/document/type/sequence/filename/d...
3473    [[/html/body/document/type/sequence/filename/d...
Length: 3474, dtype: object

In [57]:
inference_csv['xpaths_list'] = g.apply(lambda x: x[0])
inference_csv['text_list'] = g.apply(lambda x: x[1])

In [58]:
exploded = inference_csv.explode(['xpaths_list', 'text_list']).reset_index(drop=True)
exploded

Unnamed: 0,xpaths,text,xpaths_list,text_list
0,/html/body/document/type,EX-2.1,/html/body/document/type,EX
1,/html/body/document/type,EX-2.1,/html/body/document/type,-
2,/html/body/document/type,EX-2.1,/html/body/document/type,2
3,/html/body/document/type,EX-2.1,/html/body/document/type,.
4,/html/body/document/type,EX-2.1,/html/body/document/type,1
...,...,...,...,...
81140,/html/body/document/type/sequence/filename/des...,"Secretary, [ ]",/html/body/document/type/sequence/filename/des...,
81141,/html/body/document/type/sequence/filename/des...,"Secretary, [ ]",/html/body/document/type/sequence/filename/des...,
81142,/html/body/document/type/sequence/filename/des...,"Secretary, [ ]",/html/body/document/type/sequence/filename/des...,
81143,/html/body/document/type/sequence/filename/des...,"Secretary, [ ]",/html/body/document/type/sequence/filename/des...,]


In [59]:
df

Unnamed: 0,nodes,preds
0,<s>,o
1,EX,o
2,-,o
3,2,s_ssn
4,.,o
...,...,...
81403,<pad>,o
81404,<pad>,o
81405,<pad>,o
81406,<pad>,o


Looks like pad tokens are included here...

In [60]:
df.query("nodes == '<s>' | nodes == '</s>'")

Unnamed: 0,nodes,preds
0,<s>,o
81146,</s>,o


In [62]:
filter_df_based_on_start_and_end_token = df.query("nodes == '<s>' | nodes == '</s>'")
start_idx = filter_df_based_on_start_and_end_token.index.min() + 1
end_idx = filter_df_based_on_start_and_end_token.index.max()

In [63]:
start_idx, end_idx

(1, 81146)

In [64]:
preds = df.iloc[start_idx:end_idx].copy()

In [65]:
preds

Unnamed: 0,nodes,preds
1,EX,o
2,-,o
3,2,s_ssn
4,.,o
5,1,o
...,...,...
81141,,o
81142,,o
81143,,o
81144,],o


In [66]:
exploded

Unnamed: 0,xpaths,text,xpaths_list,text_list
0,/html/body/document/type,EX-2.1,/html/body/document/type,EX
1,/html/body/document/type,EX-2.1,/html/body/document/type,-
2,/html/body/document/type,EX-2.1,/html/body/document/type,2
3,/html/body/document/type,EX-2.1,/html/body/document/type,.
4,/html/body/document/type,EX-2.1,/html/body/document/type,1
...,...,...,...,...
81140,/html/body/document/type/sequence/filename/des...,"Secretary, [ ]",/html/body/document/type/sequence/filename/des...,
81141,/html/body/document/type/sequence/filename/des...,"Secretary, [ ]",/html/body/document/type/sequence/filename/des...,
81142,/html/body/document/type/sequence/filename/des...,"Secretary, [ ]",/html/body/document/type/sequence/filename/des...,
81143,/html/body/document/type/sequence/filename/des...,"Secretary, [ ]",/html/body/document/type/sequence/filename/des...,]


In [67]:
#exploded_partial = exploded.iloc[:len(preds)].copy().reset_index(drop=True)
#exploded_partial
exploded_partial = exploded.copy()

In [68]:
exploded_partial['nodes'] = preds['nodes'].to_list()
exploded_partial['preds'] = preds['preds'].to_list()
exploded_partial

Unnamed: 0,xpaths,text,xpaths_list,text_list,nodes,preds
0,/html/body/document/type,EX-2.1,/html/body/document/type,EX,EX,o
1,/html/body/document/type,EX-2.1,/html/body/document/type,-,-,o
2,/html/body/document/type,EX-2.1,/html/body/document/type,2,2,s_ssn
3,/html/body/document/type,EX-2.1,/html/body/document/type,.,.,o
4,/html/body/document/type,EX-2.1,/html/body/document/type,1,1,o
...,...,...,...,...,...,...
81140,/html/body/document/type/sequence/filename/des...,"Secretary, [ ]",/html/body/document/type/sequence/filename/des...,,,o
81141,/html/body/document/type/sequence/filename/des...,"Secretary, [ ]",/html/body/document/type/sequence/filename/des...,,,o
81142,/html/body/document/type/sequence/filename/des...,"Secretary, [ ]",/html/body/document/type/sequence/filename/des...,,,o
81143,/html/body/document/type/sequence/filename/des...,"Secretary, [ ]",/html/body/document/type/sequence/filename/des...,],],o


In [69]:
sum(exploded_partial.text_list == exploded_partial.nodes) == len(exploded_partial)

True

In [70]:
(exploded_partial['xpaths'] == exploded_partial['xpaths_list']).sum() == len(exploded_partial)

True

In [71]:
exploded_partial[['xpaths','text','nodes','preds']].to_json('testing_contract_92.json',orient='records')