In [1]:
# !unzip "/content/drive/MyDrive/LayoutLMv2/dataset/CORD.zip" -d '/content/dataset'/

In [2]:
# !pip install -q git+https://github.com/huggingface/transformers.git

In [3]:
# !pip install -q datasets seqeval
# !pip install -q pyyaml==5.1

In [4]:
# !pip install -q detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu111/torch1.10/index.html

# Prepare dataset

In [5]:
import pandas as pd

train = pd.read_pickle('/content/drive/MyDrive/LayoutLMv2/dataset/CORD_layoutlmv2_format/train.pkl')
val = pd.read_pickle('/content/drive/MyDrive/LayoutLMv2/dataset/CORD_layoutlmv2_format/dev.pkl')
test = pd.read_pickle('/content/drive/MyDrive/LayoutLMv2/dataset/CORD_layoutlmv2_format/test.pkl')

In [6]:
from collections import Counter

all_labels = [item  for sublist in train[1] for item in sublist] + [item  for sublist in val[1] for item in sublist] + [item  for sublist in test[1] for item in sublist]
Counter(all_labels)

Counter({'menu.cnt': 2429,
         'menu.discountprice': 403,
         'menu.etc': 19,
         'menu.itemsubtotal': 7,
         'menu.nm': 6597,
         'menu.num': 109,
         'menu.price': 2585,
         'menu.sub_cnt': 189,
         'menu.sub_etc': 9,
         'menu.sub_nm': 822,
         'menu.sub_price': 160,
         'menu.sub_unitprice': 14,
         'menu.unitprice': 750,
         'menu.vatyn': 9,
         'sub_total.discount_price': 191,
         'sub_total.etc': 283,
         'sub_total.othersvc_price': 6,
         'sub_total.service_price': 353,
         'sub_total.subtotal_price': 1482,
         'sub_total.tax_price': 1283,
         'total.cashprice': 1393,
         'total.changeprice': 1297,
         'total.creditcardprice': 410,
         'total.emoneyprice': 129,
         'total.menuqty_cnt': 630,
         'total.menutype_cnt': 130,
         'total.total_etc': 89,
         'total.total_price': 2120,
         'void_menu.nm': 3,
         'void_menu.price': 1})

In [7]:
replacing_labels = {'menu.etc': 'O', 'mneu.itemsubtotal': 'O', 'menu.sub_etc': 'O', 'menu.sub_unitprice': 'O', 'menu.vatyn': 'O',
                  'void_menu.nm': 'O', 'void_menu.price': 'O', 'sub_total.othersvc_price': 'O'}

In [8]:
def replace_elem(elem):
    try:
        return replacing_labels[elem]
    except:
        return elem

def replace_list(ls):
    return [replace_elem(elem) for elem in ls]

train[1] = [replace_list(ls) for ls in train[1]]
val[1] = [replace_list(ls) for ls in val[1]]
test[1] = [replace_list(ls) for ls in test[1]]

In [9]:
all_labels = [item for sublist in train[1] for item in sublist] + [item for sublist in val[1] for item in sublist] + [item for sublist in test[1] for item in sublist]
Counter(all_labels)

Counter({'O': 61,
         'menu.cnt': 2429,
         'menu.discountprice': 403,
         'menu.itemsubtotal': 7,
         'menu.nm': 6597,
         'menu.num': 109,
         'menu.price': 2585,
         'menu.sub_cnt': 189,
         'menu.sub_nm': 822,
         'menu.sub_price': 160,
         'menu.unitprice': 750,
         'sub_total.discount_price': 191,
         'sub_total.etc': 283,
         'sub_total.service_price': 353,
         'sub_total.subtotal_price': 1482,
         'sub_total.tax_price': 1283,
         'total.cashprice': 1393,
         'total.changeprice': 1297,
         'total.creditcardprice': 410,
         'total.emoneyprice': 129,
         'total.menuqty_cnt': 630,
         'total.menutype_cnt': 130,
         'total.total_etc': 89,
         'total.total_price': 2120})

In [10]:
labels = list(set(all_labels))
print(labels)

['menu.num', 'total.menutype_cnt', 'sub_total.etc', 'O', 'menu.discountprice', 'menu.sub_cnt', 'menu.price', 'total.total_price', 'total.changeprice', 'sub_total.discount_price', 'menu.cnt', 'total.cashprice', 'sub_total.tax_price', 'menu.sub_price', 'sub_total.subtotal_price', 'total.emoneyprice', 'total.total_etc', 'menu.nm', 'menu.unitprice', 'total.creditcardprice', 'menu.sub_nm', 'total.menuqty_cnt', 'sub_total.service_price', 'menu.itemsubtotal']


In [11]:
label2id = {label: idx for idx, label in enumerate(labels)}
id2label = {idx: label for idx, label in enumerate(labels)}
print(label2id)
print(id2label)

{'menu.num': 0, 'total.menutype_cnt': 1, 'sub_total.etc': 2, 'O': 3, 'menu.discountprice': 4, 'menu.sub_cnt': 5, 'menu.price': 6, 'total.total_price': 7, 'total.changeprice': 8, 'sub_total.discount_price': 9, 'menu.cnt': 10, 'total.cashprice': 11, 'sub_total.tax_price': 12, 'menu.sub_price': 13, 'sub_total.subtotal_price': 14, 'total.emoneyprice': 15, 'total.total_etc': 16, 'menu.nm': 17, 'menu.unitprice': 18, 'total.creditcardprice': 19, 'menu.sub_nm': 20, 'total.menuqty_cnt': 21, 'sub_total.service_price': 22, 'menu.itemsubtotal': 23}
{0: 'menu.num', 1: 'total.menutype_cnt', 2: 'sub_total.etc', 3: 'O', 4: 'menu.discountprice', 5: 'menu.sub_cnt', 6: 'menu.price', 7: 'total.total_price', 8: 'total.changeprice', 9: 'sub_total.discount_price', 10: 'menu.cnt', 11: 'total.cashprice', 12: 'sub_total.tax_price', 13: 'menu.sub_price', 14: 'sub_total.subtotal_price', 15: 'total.emoneyprice', 16: 'total.total_etc', 17: 'menu.nm', 18: 'menu.unitprice', 19: 'total.creditcardprice', 20: 'menu.sub_

In [12]:
import os
from torch.utils.data import Dataset
import torch
from PIL import Image

class CORDDataset(Dataset):
    def __init__(self, annotations, image_dir, processor=None, max_length=512):
        """
        Args:
            annotations (List[List]): List of lists containing the word-level annotations (words, labels, boxes)
            image_dir (string): Directory with all the document images
            processor (LayoutLMv2Processor): Processor to prepare the text + image
        """
        self.words, self.labels, self.boxes, self.file_names = annotations
        self.image_dir = image_dir
        self.image_file_names = [file_name.replace('.json', '.png') for file_name in self.file_names]
        assert len(self.image_file_names) == len(self.words)
        self.processor = processor
    
    def __len__(self):
        return len(self.image_file_names)
    
    def __getitem__(self, idx):
        image_name = self.image_file_names[idx]
        image = Image.open(self.image_dir + image_name).convert('RGB')

        words = self.words[idx]
        boxes = self.boxes[idx]
        word_labels = self.labels[idx]

        assert len(words) == len(boxes) == len(word_labels)

        word_labels = [label2id[label] for label in word_labels]

        encoded_inputs = self.processor(image, words, boxes=boxes, word_labels=word_labels,
                                       padding="max_length", truncation=True,
                                       return_tensors="pt")
        
        # remove batch dimension
        for k,v in encoded_inputs.items():
            encoded_inputs[k] = v.squeeze()
        
        assert encoded_inputs.input_ids.shape == torch.Size([512])
        assert encoded_inputs.attention_mask.shape == torch.Size([512])
        assert encoded_inputs.token_type_ids.shape == torch.Size([512])
        assert encoded_inputs.bbox.shape == torch.Size([512, 4])
        assert encoded_inputs.image.shape == torch.Size([3, 224, 224])
        assert encoded_inputs.labels.shape == torch.Size([512])

        return encoded_inputs

In [13]:
from transformers import LayoutLMv2Processor

processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased", revision="no_ocr")

train_dataset = CORDDataset(annotations=train,
                            image_dir='/content/dataset/CORD/train/image/', 
                            processor=processor)
val_dataset = CORDDataset(annotations=val,
                            image_dir='/content/dataset/CORD/dev/image/', 
                            processor=processor)
test_dataset = CORDDataset(annotations=test,
                            image_dir='/content/dataset/CORD/test/image/', 
                            processor=processor)

In [14]:
encoding = train_dataset[0]
encoding.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'bbox', 'labels', 'image'])

In [15]:
for k,v in encoding.items():
  print(k, v.shape)

input_ids torch.Size([512])
token_type_ids torch.Size([512])
attention_mask torch.Size([512])
bbox torch.Size([512, 4])
labels torch.Size([512])
image torch.Size([3, 224, 224])


In [16]:
print(processor.tokenizer.decode(encoding['input_ids']))

[CLS] 1 kfc day 34, 545 1 charge ta 909 sub total 35, 454 p. rest 10 % 3, 546 total 39, 000 cash 100, 000 kembali 61, 000 2 items, [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]

In [17]:
train[0][0]

['1',
 'KFC',
 'DAY',
 '34,545',
 '1',
 'CHARGE',
 'TA',
 '909',
 'Sub',
 'Total',
 '35,454',
 'P.Rest',
 '10',
 '%',
 '3,546',
 'Total',
 '39,000',
 'Cash',
 '100,000',
 'kembali',
 '61,000',
 '2',
 'Items,']

In [18]:
train[1][0]

['menu.cnt',
 'menu.nm',
 'menu.nm',
 'menu.price',
 'menu.cnt',
 'menu.nm',
 'menu.nm',
 'menu.price',
 'sub_total.subtotal_price',
 'sub_total.subtotal_price',
 'sub_total.subtotal_price',
 'sub_total.tax_price',
 'sub_total.tax_price',
 'sub_total.tax_price',
 'sub_total.tax_price',
 'total.total_price',
 'total.total_price',
 'total.cashprice',
 'total.cashprice',
 'total.changeprice',
 'total.changeprice',
 'total.menuqty_cnt',
 'total.menuqty_cnt']

In [19]:
[id2label[label] for label in encoding['labels'].tolist() if label != -100]

['menu.cnt',
 'menu.nm',
 'menu.nm',
 'menu.price',
 'menu.cnt',
 'menu.nm',
 'menu.nm',
 'menu.price',
 'sub_total.subtotal_price',
 'sub_total.subtotal_price',
 'sub_total.subtotal_price',
 'sub_total.tax_price',
 'sub_total.tax_price',
 'sub_total.tax_price',
 'sub_total.tax_price',
 'total.total_price',
 'total.total_price',
 'total.cashprice',
 'total.cashprice',
 'total.changeprice',
 'total.changeprice',
 'total.menuqty_cnt',
 'total.menuqty_cnt']

In [20]:
for id, label in zip(encoding['input_ids'][:30], encoding['labels'][:30]):
  print(processor.tokenizer.decode([id]), label.item())

[CLS] -100
1 10
k 17
##fc -100
day 17
34 6
, -100
54 -100
##5 -100
1 10
charge 17
ta 17
90 6
##9 -100
sub 14
total 14
35 14
, -100
45 -100
##4 -100
p 12
. -100
rest -100
10 12
% 12
3 12
, -100
54 -100
##6 -100
total 7


In [21]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=2, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=2)

In [22]:
torch.cuda.empty_cache()

# Train model

In [23]:
from transformers import LayoutLMv2ForTokenClassification, AdamW
import torch
from tqdm.notebook import tqdm

model = LayoutLMv2ForTokenClassification.from_pretrained('microsoft/layoutlmv2-base-uncased',
                                                                      num_labels=len(labels))

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
optimizer = AdamW(model.parameters(), lr=5e-5)

global_step = 0
num_train_epochs = 10

# put the model in training mode
model.train()
for epoch in range(num_train_epochs):
    print("Epoch:", epoch)
    for batch in tqdm(train_dataloader):
        input_ids = batch['input_ids'].to(device)
        bbox = batch['bbox'].to(device)
        image = batch['image'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimizer
        outputs = model(input_ids=input_ids,
                        bbox=bbox,
                        image=image,
                        attention_mask=attention_mask,
                        labels=labels)
        loss = outputs.loss

        if global_step % 100 == 0:
            print(f'Loss after {global_step} steps: {loss.item()}')
        
        loss.backward()
        optimizer.step()
        global_step += 1

model.save_pretrained('/content/drive/MyDrive/LayoutLMv2/Checkpoints')

Some weights of the model checkpoint at microsoft/layoutlmv2-base-uncased were not used when initializing LayoutLMv2ForTokenClassification: ['layoutlmv2.visual.backbone.bottom_up.res4.2.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.17.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.1.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res3.1.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.12.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res3.2.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res5.0.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.15.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.21.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.18.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.0.conv2.norm.num_batches_tracke

Epoch: 0




  0%|          | 0/100 [00:00<?, ?it/s]

  // self.config.image_feature_pool_shape[1]
  // self.config.image_feature_pool_shape[0]


Loss after 0 steps: 3.186067581176758
Epoch: 1


  0%|          | 0/100 [00:00<?, ?it/s]

Loss after 100 steps: 1.63863205909729
Epoch: 2


  0%|          | 0/100 [00:00<?, ?it/s]

Loss after 200 steps: 1.1427818536758423
Epoch: 3


  0%|          | 0/100 [00:00<?, ?it/s]

Loss after 300 steps: 0.8650424480438232
Epoch: 4


  0%|          | 0/100 [00:00<?, ?it/s]

Loss after 400 steps: 0.5492355823516846
Epoch: 5


  0%|          | 0/100 [00:00<?, ?it/s]

Loss after 500 steps: 0.4166659414768219
Epoch: 6


  0%|          | 0/100 [00:00<?, ?it/s]

Loss after 600 steps: 0.3413824439048767
Epoch: 7


  0%|          | 0/100 [00:00<?, ?it/s]

Loss after 700 steps: 0.294548898935318
Epoch: 8


  0%|          | 0/100 [00:00<?, ?it/s]

Loss after 800 steps: 0.22747154533863068
Epoch: 9


  0%|          | 0/100 [00:00<?, ?it/s]

Loss after 900 steps: 0.2656245827674866


# Evaluation

In [24]:
encoding = test_dataset[0]
processor.tokenizer.decode(encoding['input_ids'])

'[CLS] kupon 15 100, 000 add chicken box 909 subtotal 100, 909 pb1 ( 10 % ) 10, 091 total 111, 000 cash rp. 111, 000 [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [

In [25]:
ground_truth_labels = [id2label[label] for label in encoding['labels'].squeeze().tolist() if label != -100]
print(ground_truth_labels)

['menu.nm', 'menu.nm', 'menu.price', 'menu.sub_nm', 'menu.sub_nm', 'menu.sub_nm', 'menu.sub_price', 'sub_total.subtotal_price', 'sub_total.subtotal_price', 'sub_total.tax_price', 'sub_total.tax_price', 'sub_total.tax_price', 'total.total_price', 'total.total_price', 'total.cashprice', 'total.cashprice']


In [26]:
for k,v in encoding.items():
    encoding[k] = v.unsqueeze(0).to(device)

model.eval()
outputs = model(input_ids=encoding['input_ids'], attention_mask=encoding['attention_mask'],
                token_type_ids=encoding['token_type_ids'], bbox=encoding['bbox'],
                image=encoding['image'])

  // self.config.image_feature_pool_shape[1]
  // self.config.image_feature_pool_shape[0]


In [27]:
prediction_indices = outputs.logits.argmax(-1).squeeze().tolist()
print(prediction_indices)

[17, 17, 17, 17, 6, 6, 6, 20, 20, 20, 13, 13, 14, 14, 14, 14, 14, 14, 14, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 7, 7, 7, 7, 11, 11, 11, 11, 11, 11, 11, 17, 21, 21, 11, 11, 11, 11, 11, 21, 21, 21, 21, 11, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 11, 8, 11, 21, 21, 8, 11, 11, 8, 8, 11, 8, 11, 8, 8, 21, 21, 21, 21, 21, 21, 8, 8, 21, 8, 21, 8, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 11, 8, 21, 21, 21, 21, 8, 21, 8, 8, 21, 21, 21, 21, 21, 21, 21, 21, 21, 8, 21, 8, 8, 8, 8, 11, 11, 8, 11, 8, 2, 8, 17, 17, 8, 8, 8, 17, 2, 8, 8, 8, 8, 8, 21, 11, 11, 8, 8, 8, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 21, 11, 11, 11, 11, 11, 21, 11, 11, 11, 11, 11, 11, 11, 21, 11, 11, 11, 11, 11, 11, 21, 11, 11, 21, 11, 21, 8, 8, 8, 11, 8, 8, 11, 11, 11, 11, 11, 11, 11, 11, 11, 20, 17, 17, 20, 17, 17, 8, 8, 2, 11, 8, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11,

In [28]:
prediction_indices = outputs.logits.argmax(-1).squeeze().tolist()
predictions = [id2label[label] for gt, label in zip(encoding['labels'].squeeze().tolist(), prediction_indices) if gt != -100]
print(predictions)

['menu.nm', 'menu.nm', 'menu.price', 'menu.sub_nm', 'menu.sub_nm', 'menu.sub_nm', 'menu.sub_price', 'sub_total.subtotal_price', 'sub_total.subtotal_price', 'sub_total.tax_price', 'sub_total.tax_price', 'sub_total.tax_price', 'total.total_price', 'total.total_price', 'total.cashprice', 'total.cashprice']


In [29]:
import numpy as np

preds_val = None
out_label_ids = None

# put model in evaluation mode
model.eval()
for batch in tqdm(test_dataloader, desc="Evaluating"):
    with torch.no_grad():
        input_ids = batch['input_ids'].to(device)
        bbox = batch['bbox'].to(device)
        image = batch['image'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        token_type_ids = batch['token_type_ids'].to(device)
        labels = batch['labels'].to(device)

        # forward pass
        outputs = model(input_ids=input_ids, bbox=bbox, image=image, attention_mask=attention_mask, 
                        token_type_ids=token_type_ids, labels=labels)
        
        if preds_val is None:
          preds_val = outputs.logits.detach().cpu().numpy()
          out_label_ids = batch["labels"].detach().cpu().numpy()
        else:
          preds_val = np.append(preds_val, outputs.logits.detach().cpu().numpy(), axis=0)
          out_label_ids = np.append(
              out_label_ids, batch["labels"].detach().cpu().numpy(), axis=0
          )

Evaluating:   0%|          | 0/50 [00:00<?, ?it/s]

  // self.config.image_feature_pool_shape[1]
  // self.config.image_feature_pool_shape[0]


In [30]:
import warnings
warnings.filterwarnings("ignore")
from seqeval.metrics import (
    classification_report,
    f1_score,
    precision_score,
    recall_score)

def results_test(preds, out_label_ids, labels):
  preds = np.argmax(preds, axis=2)

  label_map = {i: label for i, label in enumerate(labels)}

  out_label_list = [[] for _ in range(out_label_ids.shape[0])]
  preds_list = [[] for _ in range(out_label_ids.shape[0])]

  for i in range(out_label_ids.shape[0]):
      for j in range(out_label_ids.shape[1]):
          if out_label_ids[i, j] != -100:
              out_label_list[i].append(label_map[out_label_ids[i][j]])
              preds_list[i].append(label_map[preds[i][j]])

  results = {
      "precision": precision_score(out_label_list, preds_list),
      "recall": recall_score(out_label_list, preds_list),
      "f1": f1_score(out_label_list, preds_list),
  }
  return results, classification_report(out_label_list, preds_list)

In [31]:
labels = list(set(all_labels))
val_result, class_report = results_test(preds_val, out_label_ids, labels)
print("Overall results:", val_result)
print(class_report)

Overall results: {'precision': 0.9099774943735934, 'recall': 0.9196360879454132, 'f1': 0.9147812971342384}
                         precision    recall  f1-score   support

                enu.cnt       0.98      0.94      0.96       224
      enu.discountprice       0.88      0.70      0.78        10
       enu.itemsubtotal       0.00      0.00      0.00         6
                 enu.nm       0.94      0.82      0.88       251
                enu.num       0.92      1.00      0.96        11
              enu.price       0.98      0.95      0.97       247
            enu.sub_cnt       0.57      1.00      0.72        17
             enu.sub_nm       0.41      0.97      0.57        32
          enu.sub_price       0.67      1.00      0.80        20
          enu.unitprice       0.98      0.94      0.96        68
         otal.cashprice       0.96      0.94      0.95        71
       otal.changeprice       0.97      0.98      0.97        59
   otal.creditcardprice       0.94      0.94   