In [1]:
import pandas as pd 
import torch
from torchmetrics import Accuracy
import transformers
import lightning.pytorch as pl
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from model.model import ResumeParser

In [3]:
from model.utils import label_idx, idx_label

# Hyper-parameters passed as `args=` when the ResumeParser checkpoint is restored.
args = dict(
    positional_dim=32,
    hidden_dim=256,
    classifier_dropout=0.3,
    num_classes=4,
    use_llm=False,
    n_hidden=1,  # total layers: n_hidden + 2
)

# Keyword options intended for tokenizer calls.
tokenizer_args = dict(padding='max_length', return_tensors='pt')

In [4]:
from transformers import BertTokenizerFast, BertModel

In [5]:
# Both the tokenizer and the encoder are loaded from the same pretrained
# checkpoint name, so keep it in one place.
bert_checkpoint = "bert-base-cased"
tokenizer = BertTokenizerFast.from_pretrained(bert_checkpoint)
model = BertModel.from_pretrained(bert_checkpoint)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [6]:
# Restore the trained parser from its checkpoint, handing it the BERT model as
# `backend` and the hyper-parameter dict as `args`, then call .eval() on it.
# The result is reassigned (rather than left as a bare expression) so the cell
# does not echo the module repr.
parser = ResumeParser.load_from_checkpoint(
    "model/epoch_style.ckpt",
    backend=model,
    args=args,
)
parser = parser.eval()

Device:  cpu


In [7]:
# Path of the PDF to annotate; passed to AnnotationObject in the next cell.
in_file = "./data/pdf/jihualde_CV.pdf"

In [8]:
from annotation_object import AnnotationObject, serialize
# Build the annotation state for the input PDF. The transition loop further
# down reads its json_format, wrapper, stack, current_idx, n_lines and is_done
# attributes and drives it through discard/merge/pop/subordinate calls.
anno = AnnotationObject(in_file)

Using laparams =  <LAParams: char_margin=2.0, line_margin=0.5, word_margin=0.1 all_texts=False>


30it [00:01, 17.37it/s]


Document has  30 pages
Length of dep: 1298


In [9]:
json_format = anno.json_format
wrapper = anno.wrapper
lines = wrapper.lines

# Index used for the artificial "$ROOT" node on the stack / in the buffer.
ROOT_IDX = -1


def _line_features(idx, json_format, wrapper):
    """Return ``(left, right, bold, italic)`` for the line at ``idx``.

    ``left``/``right`` are the line's ``x`` and ``x + width`` taken from
    ``json_format``; ``bold``/``italic`` are 0/1 flags derived from the font
    name of the first object of the corresponding layout line in
    ``wrapper.lines``. For ``ROOT_IDX`` the fixed placeholder
    ``(0, 100, 0, 0)`` is returned.

    Raises:
        KeyError: if the (page, idx_in_page) pair recorded in ``json_format``
            cannot be found in ``wrapper.lines``; the original lookup error is
            chained as the cause.
    """
    if idx == ROOT_IDX:
        return 0, 100, 0, 0

    entry = json_format[idx]
    left = entry['x']
    right = entry['x'] + entry['width']
    page, idx_in_page = entry['page'], entry['idx_in_page']
    try:
        line = wrapper.lines[page][idx_in_page]
    except (IndexError, KeyError) as err:
        # Report the line count of the page that was actually requested; the
        # page itself may be missing, so guard this lookup as well.
        try:
            lines_on_page = len(wrapper.lines[page])
        except (IndexError, KeyError):
            lines_on_page = "?"
        raise KeyError(
            f"Tried to get line #{idx_in_page} of page {page}; document has "
            f"{len(wrapper.elements)}/{len(wrapper.lines)} pages, "
            f"and that page has {lines_on_page} lines"
        ) from err

    fontname = line._objs[0].fontname.lower()
    return left, right, int("bold" in fontname), int("italic" in fontname)


# Action label -> (transition method, whether success consumes a buffer line).
# A transition is treated as successful when it returns 0. "pop" does not
# advance the progress bar.
transitions = {
    "discard": (anno.discard, True),
    "merge": (anno.merge_action, True),
    "pop": (anno.pop_action, False),
    "subordinate": (anno.subordinate_action, True),
}

with tqdm(total = anno.n_lines) as pbar:
    while not anno.is_done:
        stk_idx = anno.stack[-1]
        buf_idx = anno.current_idx
        lbuf, rbuf, boldbuf, italbuf = _line_features(buf_idx, json_format, wrapper)
        lstk, rstk, boldstk, italstk = _line_features(stk_idx, json_format, wrapper)

        sty = torch.Tensor([[italbuf, boldbuf, italstk, boldstk]]).long()
        pos = torch.floor(torch.Tensor([lbuf, rbuf, lstk, rstk])).long()
        batch = (None, None, pos, sty, None)
        # Pure inference: no need to record an autograd graph.
        with torch.no_grad():
            logits = parser.get_logits(batch)
        action_order = (-logits).argsort().squeeze() #largest probabilities first

        # Try actions from most to least probable until one is accepted.
        for action in action_order:
            predicted_action = idx_label[action.item()]
            transition = transitions.get(predicted_action)
            if transition is None:
                continue
            apply_transition, consumes_line = transition
            if apply_transition() == 0:
                if consumes_line:
                    pbar.update(1)
                break
        else:
            # No candidate action was accepted; continuing would loop forever
            # on the same state, so fail loudly (also under `python -O`).
            raise RuntimeError(
                f"No legal action for stack index {stk_idx} and buffer index {buf_idx}"
            )

serialize(anno)

100%|██████████| 1298/1298 [00:00<00:00, 9025.23it/s]


Serializing... <annotation_object.AnnotationObject object at 0x16c267d30>
Dumped to ./data/pkl/jihualde_CV.pkl.
