<a href="https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/MarkupLM/Fine_tune_MarkupLMForTokenClassification_on_a_custom_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Set-up environment

First, we install 🤗 Transformers.

We also install 🤗 Evaluate and seqeval, which we use to compute metrics such as precision, recall and F1.

In [1]:
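import os

# Expose only GPU 1 to this process; this must happen before CUDA is initialized
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import torch

# Use the (single visible) GPU if available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")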

In [2]:
torch.cuda.device_count()

1

## Prepare dataset

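Next, let's load the toy dataset we'll fine-tune MarkupLM on: camera product pages from several web shops, stored in a JSON file that maps each website to its list of pages.

The goal for the model is to label each node of an HTML page with the appropriate class (model, price, manufacturer or other). To measure how well the model generalizes to websites it has not seen, we put all pages of `k` randomly sampled websites in the training set and use the pages of the remaining websites for validation.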
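The loading code below assumes that each page stores three parallel lists under the keys `nodes`, `xpaths` and `node_labels`, with one integer class id per node. The record here is a hypothetical, hand-made example of that layout (the texts, XPaths and price are invented), together with a small helper, `check_page` (also hypothetical), that verifies the three lists line up and that every label is one of the four class ids used later in this notebook.

```python
# Hypothetical page in the layout this notebook reads: parallel lists of node
# texts, their XPaths, and one integer class id per node.
toy_dataset = {
    "example-shop": [
        {
            "nodes": ["Canon PowerShot D10", "$329.99", "Canon", "Home"],
            "xpaths": [
                "/html/body/div/h1",
                "/html/body/div/span[1]",
                "/html/body/div/span[2]",
                "/html/body/ul/li[1]/a",
            ],
            # 0 = model, 1 = price, 2 = manufacturer, 3 = other
            "node_labels": [0, 1, 2, 3],
        }
    ]
}

NUM_CLASSES = 4


def check_page(page):
    """Return True if a page's three lists are aligned and all labels are valid."""
    n = len(page["nodes"])
    if len(page["xpaths"]) != n or len(page["node_labels"]) != n:
        return False
    return all(0 <= label < NUM_CLASSES for label in page["node_labels"])


# Flatten website -> pages into a single list of pages
pages = [page for website_pages in toy_dataset.values() for page in website_pages]
print(all(check_page(page) for page in pages))  # prints True
```

Running such a check over the real JSON file before training catches misaligned records early, instead of as an indexing error inside the processor.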

In [3]:
import numpy as np

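def sample_websites(dataset, k, seed):
    """Put all pages of k randomly chosen websites in train and the rest in valid."""
    all_websites = list(dataset.keys())
    # draw from a seeded generator so that the split is reproducible
    rng = np.random.default_rng(seed)
    training_websites = rng.choice(all_websites, size=k, replace=False)
    print(training_websites)

    train = []
    valid = []
    for website in all_websites:
        if website in training_websites:
            train += dataset[website]
        else:
            valid += dataset[website]
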
    return train, valid
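`sample_websites` splits by website rather than by page, so validation pages come from sites whose templates the model never saw during training. Below is a self-contained sketch of the same website-level split on a hypothetical toy dataset, using a hypothetical `split_by_website` helper. It draws from `np.random.default_rng(seed)` instead of the global `np.random.choice`, so the same seed always yields the same split, and it checks that no website ends up on both sides.

```python
import numpy as np


def split_by_website(dataset, k, seed):
    """Put all pages of k randomly chosen websites in train, the rest in valid."""
    rng = np.random.default_rng(seed)
    websites = sorted(dataset)  # fixed order, so the seed fully determines the draw
    train_sites = set(rng.choice(websites, size=k, replace=False).tolist())
    train = [page for site in websites if site in train_sites for page in dataset[site]]
    valid = [page for site in websites if site not in train_sites for page in dataset[site]]
    return train, valid, train_sites


# Hypothetical dataset: website name -> list of pages (pages reduced to ids here)
toy = {f"site{i}": [f"site{i}-page{j}" for j in range(3)] for i in range(8)}

train, valid, train_sites = split_by_website(toy, k=5, seed=0)
valid_sites = {page.split("-")[0] for page in valid}
print(len(train), len(valid))     # prints 15 9
print(train_sites & valid_sites)  # prints set()
```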

In [4]:
import json

with open("/home/savkin/vera/all_camera_datasets_merged.json", "r") as file:
    data = json.load(file)

train_data, valid_data = sample_websites(data, k=5, seed=0)

['jr' 'pcnation' 'ecost' 'amazon' 'beachaudio']


In [5]:
print(train_data[0])

{'nodes': [['Canon 3506b001 12.1 Megapixel Waterproof Powershot D10 (3508b001)', 'var runInitFB = true;', '[if lte IE 7]>\n<link href="/site/files/min_iexplorer.css?v=3.4.1" rel="stylesheet" type="text/css" />\n<![endif]', 'Home', '|', 'Order Status', '|', 'Help Center', '|', 'FAQ', '|', 'Return Request', '|', 'Log In', 'SEARCH', 'Entire Site', 'audio', 'video', 'dvd', 'photography', 'communications', 'gaming', 'car', 'computers', 'appliances', 'music & dj', 'blank media', 'office', 'for', 'SHOP BY BRAND', 'audio', 'video', 'dvd', 'photography', 'communications', 'gaming', 'car', 'computers', 'appliances', 'music & dj', 'blank media', 'office', 'More...', 'Home', '»', 'Photography', '»', 'Digital Cameras', '»', '3508B001', '[productwiki]', '[/productwiki]', 'Power Reviews p:1', 'var pr_page_id =\'271456\';\n    var pr_zip_location=\'/powerreviews\';\n    var pr_write_review=\'/review_submission.php?pageId=271456&products_id=271456\';\n\tvar pr_write_answer=\'/question_submission.php?pa

## Create PyTorch Dataset

Next, we'll create a regular PyTorch dataset. Each item of the dataset is one web page, encoded using `MarkupLMProcessor`. Note that we set the processor's `parse_html` attribute to `False`, as we have already parsed the HTML ourselves and provide the nodes, XPaths and node labels directly.

Note that by default, the processor labels only the first token of each node and gives the remaining tokens the label -100, which the loss ignores. You can change this by setting the `only_label_first_subword` attribute of the processor's tokenizer to `False`.
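To make the "first subword only" rule concrete, the sketch below reproduces it in plain Python. It assumes a hypothetical `token_to_node` list giving, for each token, the index of the node it came from (`None` for special tokens such as `<s>` and `</s>`). It only illustrates the labeling rule and is not how the library code is written; `align_labels` is a hypothetical name.

```python
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss skips by default


def align_labels(token_to_node, node_labels, only_label_first_subword=True):
    """Give each token its node's label, or IGNORE_INDEX."""
    labels = []
    previous_node = None
    for node in token_to_node:
        if node is None:  # special token, never labeled
            labels.append(IGNORE_INDEX)
        elif node != previous_node or not only_label_first_subword:
            labels.append(node_labels[node])
        else:  # a later subword of the same node
            labels.append(IGNORE_INDEX)
        previous_node = node
    return labels


# Hypothetical alignment: <s>, node 0 split into 3 tokens, node 1 into 2 tokens, </s>
token_to_node = [None, 0, 0, 0, 1, 1, None]
node_labels = [0, 3]  # node 0 -> "model", node 1 -> "other"

print(align_labels(token_to_node, node_labels))
# prints [-100, 0, -100, -100, 3, -100, -100]
print(align_labels(token_to_node, node_labels, only_label_first_subword=False))
# prints [-100, 0, 0, 0, 3, 3, -100]
```

With the default, every node contributes exactly one labeled token, so the metrics later on count nodes rather than subword tokens.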

In [6]:
from torch.utils.data import Dataset

max_length=384

class MarkupLMDataset(Dataset):
    """Dataset for token classification with MarkupLM."""

    def __init__(self, data, processor=None, max_length=max_length):
        self.data = data
        self.processor = processor
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # first, get nodes, xpaths and node labels
        item = self.data[idx]
        nodes, xpaths, node_labels = item['nodes'], item['xpaths'], item['node_labels']

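        # provide to processor; pad every example to max_length so that the
        # default DataLoader collate function can stack examples into a batch
        encoding = self.processor(nodes=nodes,
                                  xpaths=xpaths,
                                  node_labels=node_labels,
                                  padding="max_length",
                                  truncation=True,
                                  max_length=self.max_length,
                                  return_tensors="pt")

        # remove the batch dimension added by return_tensors="pt"
        encoding = {k: v.squeeze(0) for k, v in encoding.items()}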

        return encoding

In [7]:
from transformers import MarkupLMProcessor

processor = MarkupLMProcessor.from_pretrained("microsoft/markuplm-base")
processor.parse_html = False

train_dataset = MarkupLMDataset(data=train_data, processor=processor, max_length=max_length)
valid_dataset = MarkupLMDataset(data=valid_data, processor=processor, max_length=max_length)



## Create PyTorch Dataloaders

The next step is to create a PyTorch DataLoader, which allows us to get batches from the dataset.

In [8]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# there is no need to shuffle the evaluation data
eval_dataloader = DataLoader(valid_dataset, batch_size=64, shuffle=False)
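PyTorch's default collate function builds a batch by stacking each field of the examples into one tensor, which only works when every example has the same shape. Encoding a single example with `padding=True` pads it only to its own length, so any page shorter than `max_length` tokens would produce a shorter tensor and break batching; padding every example to `max_length` avoids this. The numpy sketch below shows the same stacking constraint, with a hypothetical `pad_to_length` helper and a made-up pad id standing in for the processor's padding.

```python
import numpy as np

MAX_LENGTH = 8       # small stand-in for max_length=384
PAD_ID = 1           # hypothetical pad token id, for illustration only
IGNORE_INDEX = -100  # padded positions get no label


def pad_to_length(values, length, pad_value):
    """Right-pad (or truncate) a 1-D sequence to exactly `length` entries."""
    values = list(values)[:length]
    return np.array(values + [pad_value] * (length - len(values)))


# Two hypothetical encoded examples of different lengths
input_ids = [[0, 11, 12, 13, 2], [0, 21, 22, 2]]
labels = [[-100, 0, -100, 3, -100], [-100, 1, 3, -100]]

# np.stack(input_ids) would raise a ValueError here: the rows differ in length.
batch_ids = np.stack([pad_to_length(x, MAX_LENGTH, PAD_ID) for x in input_ids])
batch_labels = np.stack([pad_to_length(y, MAX_LENGTH, IGNORE_INDEX) for y in labels])
print(batch_ids.shape, batch_labels.shape)  # prints (2, 8) (2, 8)
```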

## Define model

We define the model here: a MarkupLM-base Transformer with a token classification head on top. The classification head has randomly initialized weights, while the base Transformer has pre-trained weights.

In [9]:
from transformers import MarkupLMForTokenClassification

id2label = {0: "model", 1: "price", 2: "manufacturer", 3: "other"}
label2id = {label:id for id, label in id2label.items()}

model = MarkupLMForTokenClassification.from_pretrained("microsoft/markuplm-base", id2label=id2label, label2id=label2id)

Some weights of MarkupLMForTokenClassification were not initialized from the model checkpoint at microsoft/markuplm-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
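Conceptually, the token classification head is a linear layer applied independently to every token's final hidden state, producing one score (logit) per label; the predicted class is the argmax over the last dimension, which is what `outputs.logits.argmax(dim=-1)` computes in the training loop later on. The numpy sketch below illustrates only that shape logic, with random weights and tiny hypothetical sizes (MarkupLM-base's actual hidden size is 768); it is not the library's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

batch_size, seq_len, hidden_size, num_labels = 2, 5, 8, 4  # toy sizes

# Stand-in for the encoder's final hidden states: one vector per token
hidden_states = rng.normal(size=(batch_size, seq_len, hidden_size))

# Randomly initialized classifier, like the newly initialized head in the warning above
weight = rng.normal(scale=0.02, size=(hidden_size, num_labels))
bias = np.zeros(num_labels)

# The same linear map is applied at every token position
logits = hidden_states @ weight + bias
predictions = logits.argmax(axis=-1)

print(logits.shape)       # prints (2, 5, 4)
print(predictions.shape)  # prints (2, 5)
```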


We also create a label_list, where each tag starts with a B (as seqeval expects the labels to be in IOB format).

In [10]:
label_list = ["B-" + x for x in list(id2label.values())]
label_list

['B-model', 'B-price', 'B-manufacturer', 'B-other']
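Because every tag carries the `B-` prefix, seqeval counts each labeled token as its own one-token entity, and `other` is scored as an entity type just like the three product fields. Since `other` nodes vastly outnumber the rest, they dominate the `overall_*` scores reported later. One option, not used in this notebook, is to write `other` as the IOB outside tag `O`, which seqeval does not count as an entity. The sketch below converts one sequence of class ids both ways; `to_tags` and its `outside_label` parameter are hypothetical names introduced for illustration.

```python
id2label = {0: "model", 1: "price", 2: "manufacturer", 3: "other"}


def to_tags(label_ids, outside_label=None):
    """Convert class ids to IOB tags, skipping ignored (-100) positions.

    Every kept token becomes a one-token entity "B-<label>", except the
    optional `outside_label`, which is written as the outside tag "O".
    """
    tags = []
    for label_id in label_ids:
        if label_id == -100:
            continue
        label = id2label[label_id]
        tags.append("O" if label == outside_label else "B-" + label)
    return tags


label_ids = [-100, 0, -100, 3, 3, 1, -100]
print(to_tags(label_ids))
# prints ['B-model', 'B-other', 'B-other', 'B-price']
print(to_tags(label_ids, outside_label="other"))
# prints ['B-model', 'O', 'O', 'B-price']
```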

We also define metric calculations (as we'd like to know the F1 score etc. during training). We'll use 🤗 [Evaluate](https://huggingface.co/docs/evaluate/index) for that, which is a library containing many tools for evaluating ML models.

In [11]:
import evaluate

# Metric
metric = evaluate.load("seqeval")

def get_labels(predictions, references):
    # Move the prediction and reference tensors to the CPU and convert them to numpy arrays
    # (.cpu() is a no-op for tensors that are already on the CPU)
    y_pred = predictions.detach().cpu().numpy()
    y_true = references.detach().cpu().numpy()

    # Remove ignored index (special tokens and non-first subwords, labeled -100)
    true_predictions = [
        [label_list[p] for (p, l) in zip(pred, gold_label) if l != -100]
        for pred, gold_label in zip(y_pred, y_true)
    ]
    true_labels = [
        [label_list[l] for (p, l) in zip(pred, gold_label) if l != -100]
        for pred, gold_label in zip(y_pred, y_true)
    ]
    return true_predictions, true_labels

def compute_metrics(return_entity_level_metrics=True):
    results = metric.compute()
    if return_entity_level_metrics:
        # Unpack nested dictionaries
        final_results = {}
        for key, value in results.items():
            if isinstance(value, dict):
                for n, v in value.items():
                    final_results[f"{key}_{n}"] = v
            else:
                final_results[key] = value
        return final_results
    else:
        return {
            "precision": results["overall_precision"],
            "recall": results["overall_recall"],
            "f1": results["overall_f1"],
            "accuracy": results["overall_accuracy"],
        }

## Train

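Alright, let's train! Here we train the model in native PyTorch, but you could also use 🤗 Accelerate, the 🤗 Trainer or PyTorch Lightning. Note that the metrics printed after each epoch are computed on the training batches seen during that epoch, not on the validation set.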

In [12]:
import torch
from torch.optim import AdamW
from tqdm.auto import tqdm

optimizer = AdamW(model.parameters(), lr=2e-5)

model.to(device)

model.train()
for epoch in range(5):  # loop over the dataset multiple times
    for batch in tqdm(train_dataloader):
        # move the inputs to the device
        inputs = {k: v.to(device) for k, v in batch.items()}

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(**inputs)

        loss = outputs.loss
        loss.backward()
        optimizer.step()

        print("Loss:", loss.item())

        # accumulate this training batch's predictions for the epoch metrics
        predictions = outputs.logits.argmax(dim=-1)
        labels = batch["labels"]
        preds, refs = get_labels(predictions, labels)
        metric.add_batch(
            predictions=preds,
            references=refs,
        )

    # metrics over the training batches of this epoch (not the validation set)
    train_metric = compute_metrics()
    print(f"Epoch {epoch}:", train_metric)

  2%|▏         | 1/56 [00:05<05:12,  5.69s/it]

Loss: 1.4233579635620117


  4%|▎         | 2/56 [00:09<04:10,  4.63s/it]

Loss: 0.9820742607116699


  5%|▌         | 3/56 [00:14<04:08,  4.69s/it]

Loss: 0.6651310920715332


  7%|▋         | 4/56 [00:18<03:51,  4.44s/it]

Loss: 0.4067155420780182


  9%|▉         | 5/56 [00:22<03:42,  4.37s/it]

Loss: 0.25720328092575073


 11%|█         | 6/56 [00:26<03:37,  4.35s/it]

Loss: 0.14913615584373474


 12%|█▎        | 7/56 [00:31<03:31,  4.31s/it]

Loss: 0.07324621081352234


 14%|█▍        | 8/56 [00:35<03:26,  4.31s/it]

Loss: 0.0403117910027504


 16%|█▌        | 9/56 [00:39<03:23,  4.33s/it]

Loss: 0.025667166337370872


 18%|█▊        | 10/56 [00:44<03:25,  4.46s/it]

Loss: 0.025784146040678024


 20%|█▉        | 11/56 [00:48<03:15,  4.34s/it]

Loss: 0.010956886224448681


 21%|██▏       | 12/56 [00:53<03:11,  4.36s/it]

Loss: 0.008753900416195393


 23%|██▎       | 13/56 [00:57<03:07,  4.35s/it]

Loss: 0.007170962169766426


 25%|██▌       | 14/56 [01:01<03:00,  4.30s/it]

Loss: 0.0058194310404360294


 27%|██▋       | 15/56 [01:05<02:51,  4.18s/it]

Loss: 0.004890337586402893


 29%|██▊       | 16/56 [01:09<02:50,  4.26s/it]

Loss: 0.004167403560131788


 30%|███       | 17/56 [01:14<02:45,  4.24s/it]

Loss: 0.012022174894809723


 32%|███▏      | 18/56 [01:18<02:42,  4.29s/it]

Loss: 0.0030289955902844667


 34%|███▍      | 19/56 [01:23<02:41,  4.35s/it]

Loss: 0.0027100695297122


 36%|███▌      | 20/56 [01:27<02:40,  4.47s/it]

Loss: 0.002364344662055373


 38%|███▊      | 21/56 [01:31<02:31,  4.33s/it]

Loss: 0.0021428142208606005


 39%|███▉      | 22/56 [01:36<02:31,  4.47s/it]

Loss: 0.01331197377294302


 41%|████      | 23/56 [01:40<02:26,  4.43s/it]

Loss: 0.0018129213713109493


 43%|████▎     | 24/56 [01:45<02:19,  4.35s/it]

Loss: 0.0016334232641384006


 45%|████▍     | 25/56 [01:49<02:15,  4.38s/it]

Loss: 0.001509722787886858


 46%|████▋     | 26/56 [01:53<02:08,  4.27s/it]

Loss: 0.011372281238436699


 48%|████▊     | 27/56 [01:58<02:06,  4.37s/it]

Loss: 0.011744876392185688


 50%|█████     | 28/56 [02:02<02:03,  4.40s/it]

Loss: 0.001267187763005495


 52%|█████▏    | 29/56 [02:06<01:57,  4.36s/it]

Loss: 0.0012187809916213155


 54%|█████▎    | 30/56 [02:11<01:53,  4.35s/it]

Loss: 0.014713110402226448


 55%|█████▌    | 31/56 [02:15<01:48,  4.35s/it]

Loss: 0.0011351207504048944


 57%|█████▋    | 32/56 [02:20<01:45,  4.39s/it]

Loss: 0.013902615755796432


 59%|█████▉    | 33/56 [02:24<01:40,  4.36s/it]

Loss: 0.0010451781563460827


 61%|██████    | 34/56 [02:28<01:32,  4.22s/it]

Loss: 0.0010082258377224207


 62%|██████▎   | 35/56 [02:32<01:28,  4.22s/it]

Loss: 0.014364558272063732


 64%|██████▍   | 36/56 [02:36<01:24,  4.23s/it]

Loss: 0.0009799021063372493


 66%|██████▌   | 37/56 [02:40<01:20,  4.23s/it]

Loss: 0.0009629664127714932


 68%|██████▊   | 38/56 [02:45<01:18,  4.35s/it]

Loss: 0.011216456070542336


 70%|██████▉   | 39/56 [02:49<01:11,  4.23s/it]

Loss: 0.0009205696987919509


 71%|███████▏  | 40/56 [02:53<01:06,  4.17s/it]

Loss: 0.0009179767221212387


 73%|███████▎  | 41/56 [02:58<01:04,  4.31s/it]

Loss: 0.0008981867576949298


 75%|███████▌  | 42/56 [03:02<01:00,  4.32s/it]

Loss: 0.0008916306542232633


 77%|███████▋  | 43/56 [03:06<00:56,  4.35s/it]

Loss: 0.000856278114952147


 79%|███████▊  | 44/56 [03:11<00:51,  4.32s/it]

Loss: 0.0008504975703544915


 80%|████████  | 45/56 [03:15<00:46,  4.26s/it]

Loss: 0.02284344471991062


 82%|████████▏ | 46/56 [03:19<00:42,  4.29s/it]

Loss: 0.0008410200825892389


 84%|████████▍ | 47/56 [03:23<00:38,  4.23s/it]

Loss: 0.0008321671048179269


 86%|████████▌ | 48/56 [03:28<00:34,  4.30s/it]

Loss: 0.0008148580673150718


 88%|████████▊ | 49/56 [03:32<00:30,  4.31s/it]

Loss: 0.0007976974593475461


 89%|████████▉ | 50/56 [03:36<00:26,  4.34s/it]

Loss: 0.0007981930393725634


 91%|█████████ | 51/56 [03:41<00:22,  4.52s/it]

Loss: 0.0007837067823857069


 93%|█████████▎| 52/56 [03:46<00:18,  4.51s/it]

Loss: 0.01241996232420206


 95%|█████████▍| 53/56 [03:50<00:13,  4.40s/it]

Loss: 0.0007547592977061868


 96%|█████████▋| 54/56 [03:54<00:08,  4.26s/it]

Loss: 0.0007609069580212235


 98%|█████████▊| 55/56 [03:59<00:04,  4.36s/it]

Loss: 0.012842394411563873


100%|██████████| 56/56 [04:00<00:00,  4.29s/it]

Loss: 0.0007476924220100045





Epoch 0: {'model_precision': 0.0016, 'model_recall': 0.058823529411764705, 'model_f1': 0.0031152647975077885, 'model_number': 17, 'other_precision': 0.9995507637017071, 'other_recall': 0.9809594665344025, 'other_f1': 0.9901678557025048, 'other_number': 36291, 'price_precision': 0.0, 'price_recall': 0.0, 'price_f1': 0.0, 'price_number': 0, 'overall_precision': 0.9805277073923102, 'overall_recall': 0.9805277073923102, 'overall_f1': 0.9805277073923102, 'overall_accuracy': 0.9805277073923102}


  2%|▏         | 1/56 [00:04<04:01,  4.39s/it]

Loss: 0.0007401028415188193


  4%|▎         | 2/56 [00:08<03:59,  4.43s/it]

Loss: 0.02660837396979332


  5%|▌         | 3/56 [00:13<03:50,  4.34s/it]

Loss: 0.0007318622665479779


  7%|▋         | 4/56 [00:17<03:41,  4.27s/it]

Loss: 0.0007336796843446791


  9%|▉         | 5/56 [00:21<03:30,  4.13s/it]

Loss: 0.0007313460810109973


 11%|█         | 6/56 [00:25<03:34,  4.30s/it]

Loss: 0.0007408285746350884


 12%|█▎        | 7/56 [00:30<03:33,  4.36s/it]

Loss: 0.0007221372798085213


 14%|█▍        | 8/56 [00:34<03:26,  4.30s/it]

Loss: 0.0007189595489762723


 16%|█▌        | 9/56 [00:38<03:20,  4.27s/it]

Loss: 0.0007218042155727744


 18%|█▊        | 10/56 [00:42<03:17,  4.30s/it]

Loss: 0.0007183073903433979


 20%|█▉        | 11/56 [00:47<03:15,  4.35s/it]

Loss: 0.000703335739672184


 21%|██▏       | 12/56 [00:51<03:11,  4.36s/it]

Loss: 0.0007047029794193804


 23%|██▎       | 13/56 [00:56<03:11,  4.46s/it]

Loss: 0.0006880658329464495


 25%|██▌       | 14/56 [01:00<03:06,  4.45s/it]

Loss: 0.013135089538991451


 27%|██▋       | 15/56 [01:05<03:04,  4.50s/it]

Loss: 0.0006747983279637992


 29%|██▊       | 16/56 [01:10<03:05,  4.64s/it]

Loss: 0.0006730343447998166


 30%|███       | 17/56 [01:14<02:58,  4.57s/it]

Loss: 0.028256090357899666


 32%|███▏      | 18/56 [01:19<02:54,  4.59s/it]

Loss: 0.000673745118547231


 34%|███▍      | 19/56 [01:23<02:47,  4.52s/it]

Loss: 0.012446822598576546


 36%|███▌      | 20/56 [01:27<02:36,  4.35s/it]

Loss: 0.010487130843102932


 38%|███▊      | 21/56 [01:32<02:34,  4.43s/it]

Loss: 0.0006937668658792973


 39%|███▉      | 22/56 [01:37<02:32,  4.50s/it]

Loss: 0.0006970829563215375


 41%|████      | 23/56 [01:41<02:24,  4.38s/it]

Loss: 0.0007073474698700011


 43%|████▎     | 24/56 [01:45<02:19,  4.37s/it]

Loss: 0.02263972908258438


 45%|████▍     | 25/56 [01:49<02:10,  4.22s/it]

Loss: 0.0007160734385251999


 46%|████▋     | 26/56 [01:53<02:08,  4.29s/it]

Loss: 0.011048855260014534


 48%|████▊     | 27/56 [01:58<02:04,  4.29s/it]

Loss: 0.000751671614125371


 50%|█████     | 28/56 [02:02<02:01,  4.33s/it]

Loss: 0.0007599959499202669


 52%|█████▏    | 29/56 [02:07<01:57,  4.37s/it]

Loss: 0.0007675751694478095


 54%|█████▎    | 30/56 [02:11<01:52,  4.33s/it]

Loss: 0.0007749589858576655


 55%|█████▌    | 31/56 [02:15<01:49,  4.39s/it]

Loss: 0.0007566810236312449


 57%|█████▋    | 32/56 [02:20<01:46,  4.43s/it]

Loss: 0.0007719001150690019


 59%|█████▉    | 33/56 [02:24<01:41,  4.43s/it]

Loss: 0.0007645281730219722


 61%|██████    | 34/56 [02:29<01:38,  4.49s/it]

Loss: 0.0007542093517258763


 62%|██████▎   | 35/56 [02:33<01:33,  4.44s/it]

Loss: 0.0007625404978170991


 64%|██████▍   | 36/56 [02:38<01:29,  4.47s/it]

Loss: 0.0007331559318117797


 66%|██████▌   | 37/56 [02:42<01:24,  4.43s/it]

Loss: 0.00988280214369297


 68%|██████▊   | 38/56 [02:47<01:21,  4.53s/it]

Loss: 0.0007391198305413127


 70%|██████▉   | 39/56 [02:51<01:16,  4.50s/it]

Loss: 0.01017853058874607


 71%|███████▏  | 40/56 [02:56<01:12,  4.50s/it]

Loss: 0.0006862420705147088


 73%|███████▎  | 41/56 [03:00<01:06,  4.45s/it]

Loss: 0.0006899258587509394


 75%|███████▌  | 42/56 [03:05<01:02,  4.48s/it]

Loss: 0.0006890463409945369


 77%|███████▋  | 43/56 [03:09<00:57,  4.44s/it]

Loss: 0.0006800286937505007


 79%|███████▊  | 44/56 [03:13<00:52,  4.40s/it]

Loss: 0.02203630656003952


 80%|████████  | 45/56 [03:18<00:47,  4.33s/it]

Loss: 0.0006671219016425312


 82%|████████▏ | 46/56 [03:22<00:43,  4.33s/it]

Loss: 0.0006723203114233911


 84%|████████▍ | 47/56 [03:27<00:39,  4.44s/it]

Loss: 0.000659262528643012


 86%|████████▌ | 48/56 [03:31<00:36,  4.58s/it]

Loss: 0.0006675625336356461


 88%|████████▊ | 49/56 [03:36<00:31,  4.47s/it]

Loss: 0.009964758530259132


 89%|████████▉ | 50/56 [03:40<00:26,  4.43s/it]

Loss: 0.0006617752369493246


 91%|█████████ | 51/56 [03:44<00:21,  4.38s/it]

Loss: 0.0006375390803441405


 93%|█████████▎| 52/56 [03:49<00:17,  4.42s/it]

Loss: 0.0006375082302838564


 95%|█████████▍| 53/56 [03:53<00:13,  4.40s/it]

Loss: 0.0006500166491605341


 96%|█████████▋| 54/56 [03:58<00:09,  4.50s/it]

Loss: 0.000652586983051151


 98%|█████████▊| 55/56 [04:02<00:04,  4.51s/it]

Loss: 0.014281469397246838


100%|██████████| 56/56 [04:04<00:00,  4.36s/it]

Loss: 0.0006261350936256349





Epoch 1: {'model_precision': 0.0, 'model_recall': 0.0, 'model_f1': 0.0, 'model_number': 17, 'other_precision': 0.9995317836289523, 'other_recall': 1.0, 'other_f1': 0.999765836995, 'other_number': 36291, 'overall_precision': 0.9995317836289523, 'overall_recall': 0.9995317836289523, 'overall_f1': 0.9995317836289523, 'overall_accuracy': 0.9995317836289523}


  2%|▏         | 1/56 [00:04<04:06,  4.48s/it]

Loss: 0.0006213058368302882


  4%|▎         | 2/56 [00:08<03:56,  4.37s/it]

Loss: 0.0006142458296380937


  5%|▌         | 3/56 [00:13<04:06,  4.65s/it]

Loss: 0.0006164807709865272


  7%|▋         | 4/56 [00:17<03:51,  4.46s/it]

Loss: 0.0006103532505221665


  9%|▉         | 5/56 [00:22<03:44,  4.40s/it]

Loss: 0.0006100016762502491


 11%|█         | 6/56 [00:26<03:36,  4.32s/it]

Loss: 0.0005877892835997045


 12%|█▎        | 7/56 [00:31<03:41,  4.53s/it]

Loss: 0.000574532023165375


 14%|█▍        | 8/56 [00:35<03:38,  4.56s/it]

Loss: 0.00056190334726125


 16%|█▌        | 9/56 [00:40<03:29,  4.45s/it]

Loss: 0.0005500447587110102


 18%|█▊        | 10/56 [00:44<03:22,  4.39s/it]

Loss: 0.011597414501011372


 20%|█▉        | 11/56 [00:48<03:11,  4.26s/it]

Loss: 0.02078929729759693


 21%|██▏       | 12/56 [00:52<03:06,  4.23s/it]

Loss: 0.0005459294188767672


 23%|██▎       | 13/56 [00:56<03:03,  4.27s/it]

Loss: 0.011851036921143532


 25%|██▌       | 14/56 [01:01<03:01,  4.33s/it]

Loss: 0.0005583949387073517


 27%|██▋       | 15/56 [01:06<03:08,  4.59s/it]

Loss: 0.017645401880145073


 29%|██▊       | 16/56 [01:10<03:01,  4.53s/it]

Loss: 0.0005911862244829535


 30%|███       | 17/56 [01:15<02:53,  4.45s/it]

Loss: 0.0006068266811780632


 32%|███▏      | 18/56 [01:19<02:47,  4.41s/it]

Loss: 0.012951911427080631


 34%|███▍      | 19/56 [01:23<02:39,  4.30s/it]

Loss: 0.0006047817296348512


 36%|███▌      | 20/56 [01:28<02:36,  4.34s/it]

Loss: 0.0006295157945714891


 38%|███▊      | 21/56 [01:32<02:32,  4.34s/it]

Loss: 0.0006328225717879832


 39%|███▉      | 22/56 [01:36<02:24,  4.25s/it]

Loss: 0.011133339256048203


 41%|████      | 23/56 [01:40<02:21,  4.28s/it]

Loss: 0.013190644793212414


 43%|████▎     | 24/56 [01:45<02:17,  4.29s/it]

Loss: 0.0006531666149385273


 45%|████▍     | 25/56 [01:50<02:19,  4.51s/it]

Loss: 0.0006989517132751644


 46%|████▋     | 26/56 [01:54<02:17,  4.57s/it]

Loss: 0.012599463574588299


 48%|████▊     | 27/56 [01:59<02:11,  4.54s/it]

Loss: 0.0154902134090662


 50%|█████     | 28/56 [02:03<02:06,  4.53s/it]

Loss: 0.0007883156649768353


 52%|█████▏    | 29/56 [02:07<01:57,  4.36s/it]

Loss: 0.0007689300109632313


 54%|█████▎    | 30/56 [02:12<01:55,  4.43s/it]

Loss: 0.0008010881720110774


 55%|█████▌    | 31/56 [02:17<01:53,  4.56s/it]

Loss: 0.0007587409927509725


 57%|█████▋    | 32/56 [02:21<01:45,  4.42s/it]

Loss: 0.00077798095298931


 59%|█████▉    | 33/56 [02:26<01:45,  4.59s/it]

Loss: 0.0008357032784260809


 61%|██████    | 34/56 [02:30<01:38,  4.48s/it]

Loss: 0.0008044550777412951


 62%|██████▎   | 35/56 [02:35<01:36,  4.57s/it]

Loss: 0.010772213339805603


 64%|██████▍   | 36/56 [02:39<01:31,  4.58s/it]

Loss: 0.0007780682062730193


 66%|██████▌   | 37/56 [02:44<01:25,  4.48s/it]

Loss: 0.008270090445876122


 68%|██████▊   | 38/56 [02:48<01:20,  4.49s/it]

Loss: 0.0007551565067842603


 70%|██████▉   | 39/56 [02:53<01:16,  4.51s/it]

Loss: 0.012381776235997677


 71%|███████▏  | 40/56 [02:57<01:12,  4.53s/it]

Loss: 0.0007020584307610989


 73%|███████▎  | 41/56 [03:02<01:08,  4.56s/it]

Loss: 0.0006857584230601788


 75%|███████▌  | 42/56 [03:06<01:02,  4.49s/it]

Loss: 0.0007267780019901693


 77%|███████▋  | 43/56 [03:10<00:56,  4.36s/it]

Loss: 0.0007176480139605701


 79%|███████▊  | 44/56 [03:15<00:52,  4.40s/it]

Loss: 0.0006524738855659962


 80%|████████  | 45/56 [03:19<00:47,  4.35s/it]

Loss: 0.0006728587322868407


 82%|████████▏ | 46/56 [03:24<00:46,  4.63s/it]

Loss: 0.0006076169665902853


 84%|████████▍ | 47/56 [03:29<00:40,  4.52s/it]

Loss: 0.009162676520645618


 86%|████████▌ | 48/56 [03:33<00:36,  4.61s/it]

Loss: 0.0006300300592556596


 88%|████████▊ | 49/56 [03:38<00:31,  4.52s/it]

Loss: 0.018627068027853966


 89%|████████▉ | 50/56 [03:42<00:27,  4.56s/it]

Loss: 0.012170800939202309


 91%|█████████ | 51/56 [03:47<00:22,  4.48s/it]

Loss: 0.0006306135910563171


 93%|█████████▎| 52/56 [03:51<00:17,  4.46s/it]

Loss: 0.0006734436028636992


 95%|█████████▍| 53/56 [03:56<00:13,  4.46s/it]

Loss: 0.0006384599255397916


 96%|█████████▋| 54/56 [04:00<00:08,  4.41s/it]

Loss: 0.0006479269941337407


 98%|█████████▊| 55/56 [04:04<00:04,  4.46s/it]

Loss: 0.0006707610446028411


100%|██████████| 56/56 [04:06<00:00,  4.40s/it]

Loss: 0.0005899178795516491





Epoch 2: {'model_precision': 0.0, 'model_recall': 0.0, 'model_f1': 0.0, 'model_number': 17, 'other_precision': 0.9995317836289523, 'other_recall': 1.0, 'other_f1': 0.999765836995, 'other_number': 36291, 'overall_precision': 0.9995317836289523, 'overall_recall': 0.9995317836289523, 'overall_f1': 0.9995317836289523, 'overall_accuracy': 0.9995317836289523}


  0%|          | 0/56 [00:00<?, ?it/s]

Loss: 0.0006325311260297894


  4%|▎         | 2/56 [00:09<04:14,  4.72s/it]

Loss: 0.0006269156001508236


  5%|▌         | 3/56 [00:14<04:10,  4.72s/it]

Loss: 0.0005982942529954016


  7%|▋         | 4/56 [00:19<04:06,  4.74s/it]

Loss: 0.0005722601781599224


  9%|▉         | 5/56 [00:23<03:53,  4.58s/it]

Loss: 0.0005770449643023312


 11%|█         | 6/56 [00:27<03:45,  4.51s/it]

Loss: 0.0005589693319052458


 12%|█▎        | 7/56 [00:31<03:37,  4.43s/it]

Loss: 0.009582123719155788


 14%|█▍        | 8/56 [00:36<03:29,  4.37s/it]

Loss: 0.0005147166084498167


 16%|█▌        | 9/56 [00:40<03:27,  4.42s/it]

Loss: 0.0005462339031510055


 18%|█▊        | 10/56 [00:44<03:19,  4.34s/it]

Loss: 0.0005081075360067189


 20%|█▉        | 11/56 [00:49<03:17,  4.39s/it]

Loss: 0.0005105319432914257


 21%|██▏       | 12/56 [00:53<03:08,  4.28s/it]

Loss: 0.010111037641763687


 23%|██▎       | 13/56 [00:58<03:09,  4.41s/it]

Loss: 0.00046233745524659753


 25%|██▌       | 14/56 [01:02<03:07,  4.45s/it]

Loss: 0.0004933346644975245


 27%|██▋       | 15/56 [01:07<03:01,  4.41s/it]

Loss: 0.012094075791537762


 29%|██▊       | 16/56 [01:11<02:58,  4.45s/it]

Loss: 0.00046861631562933326


 30%|███       | 17/56 [01:16<02:55,  4.49s/it]

Loss: 0.0004530574078671634


 32%|███▏      | 18/56 [01:20<02:54,  4.59s/it]

Loss: 0.0004708210180979222


 34%|███▍      | 19/56 [01:25<02:50,  4.61s/it]

Loss: 0.00045091635547578335


 36%|███▌      | 20/56 [01:30<02:44,  4.58s/it]

Loss: 0.00046383318840526044


 38%|███▊      | 21/56 [01:34<02:39,  4.55s/it]

Loss: 0.02275168150663376


 39%|███▉      | 22/56 [01:39<02:34,  4.56s/it]

Loss: 0.013009482994675636


 41%|████      | 23/56 [01:43<02:28,  4.49s/it]

Loss: 0.0004514413303695619


 43%|████▎     | 24/56 [01:47<02:23,  4.48s/it]

Loss: 0.0004992696340195835


 45%|████▍     | 25/56 [01:52<02:17,  4.44s/it]

Loss: 0.011375919915735722


 46%|████▋     | 26/56 [01:56<02:14,  4.50s/it]

Loss: 0.010062162764370441


 48%|████▊     | 27/56 [02:01<02:10,  4.50s/it]

Loss: 0.010836540721356869


 50%|█████     | 28/56 [02:06<02:08,  4.58s/it]

Loss: 0.000673549366183579


 52%|█████▏    | 29/56 [02:11<02:05,  4.65s/it]

Loss: 0.0006770451436750591


 54%|█████▎    | 30/56 [02:15<01:56,  4.49s/it]

Loss: 0.0006842556176707149


 55%|█████▌    | 31/56 [02:19<01:53,  4.53s/it]

Loss: 0.0007680142880417407


 57%|█████▋    | 32/56 [02:23<01:46,  4.44s/it]

Loss: 0.0006546656368300319


 59%|█████▉    | 33/56 [02:28<01:43,  4.49s/it]

Loss: 0.0007017585448920727


 61%|██████    | 34/56 [02:33<01:38,  4.49s/it]

Loss: 0.0008094875956885517


 62%|██████▎   | 35/56 [02:37<01:31,  4.37s/it]

Loss: 0.009007353335618973


 64%|██████▍   | 36/56 [02:42<01:30,  4.51s/it]

Loss: 0.0006743145640939474


 66%|██████▌   | 37/56 [02:45<01:20,  4.26s/it]

Loss: 0.0006838382687419653


 68%|██████▊   | 38/56 [02:50<01:17,  4.28s/it]

Loss: 0.00869715679436922


 70%|██████▉   | 39/56 [02:54<01:12,  4.25s/it]

Loss: 0.0006805082666687667


 71%|███████▏  | 40/56 [02:58<01:09,  4.35s/it]

Loss: 0.0006777290254831314


 73%|███████▎  | 41/56 [03:03<01:06,  4.47s/it]

Loss: 0.0007296146359294653


 75%|███████▌  | 42/56 [03:08<01:02,  4.48s/it]

Loss: 0.007777153514325619


 77%|███████▋  | 43/56 [03:12<00:58,  4.47s/it]

Loss: 0.018410641700029373


 79%|███████▊  | 44/56 [03:17<00:54,  4.52s/it]

Loss: 0.0006255790358409286


 80%|████████  | 45/56 [03:21<00:48,  4.42s/it]

Loss: 0.000680184515658766


 82%|████████▏ | 46/56 [03:25<00:43,  4.31s/it]

Loss: 0.0007039750926196575


 84%|████████▍ | 47/56 [03:29<00:38,  4.31s/it]

Loss: 0.0006623574299737811


 86%|████████▌ | 48/56 [03:33<00:34,  4.30s/it]

Loss: 0.0006862875889055431


 88%|████████▊ | 49/56 [03:38<00:31,  4.44s/it]

Loss: 0.018835105001926422


 89%|████████▉ | 50/56 [03:43<00:26,  4.40s/it]

Loss: 0.0006568903336301446


 91%|█████████ | 51/56 [03:47<00:22,  4.44s/it]

Loss: 0.0006861735018901527


 93%|█████████▎| 52/56 [03:51<00:17,  4.41s/it]

Loss: 0.0007000619661994278


 95%|█████████▍| 53/56 [03:55<00:12,  4.23s/it]

Loss: 0.00849150586873293


 96%|█████████▋| 54/56 [03:59<00:08,  4.22s/it]

Loss: 0.000648479734081775


 98%|█████████▊| 55/56 [04:04<00:04,  4.18s/it]

Loss: 0.0006899075815454125


100%|██████████| 56/56 [04:05<00:00,  4.38s/it]

Loss: 0.0006625596433877945





Epoch 3: {'model_precision': 0.0, 'model_recall': 0.0, 'model_f1': 0.0, 'model_number': 17, 'other_precision': 0.9995317836289523, 'other_recall': 1.0, 'other_f1': 0.999765836995, 'other_number': 36291, 'overall_precision': 0.9995317836289523, 'overall_recall': 0.9995317836289523, 'overall_f1': 0.9995317836289523, 'overall_accuracy': 0.9995317836289523}


  2%|▏         | 1/56 [00:04<04:29,  4.89s/it]

Loss: 0.010747256688773632


  4%|▎         | 2/56 [00:09<04:01,  4.48s/it]

Loss: 0.0005693078273907304


  5%|▌         | 3/56 [00:13<04:01,  4.55s/it]

Loss: 0.009867382235825062


  7%|▋         | 4/56 [00:17<03:50,  4.44s/it]

Loss: 0.000591165735386312


  9%|▉         | 5/56 [00:22<03:43,  4.37s/it]

Loss: 0.0006876355619169772


 11%|█         | 6/56 [00:26<03:34,  4.29s/it]

Loss: 0.0007622696575708687


 12%|█▎        | 7/56 [00:30<03:28,  4.25s/it]

Loss: 0.0005793525488115847


 14%|█▍        | 8/56 [00:34<03:22,  4.21s/it]

Loss: 0.0007324846810661256


 16%|█▌        | 9/56 [00:38<03:18,  4.23s/it]

Loss: 0.005652423482388258


 18%|█▊        | 10/56 [00:44<03:26,  4.50s/it]

Loss: 0.000619021593593061


 20%|█▉        | 11/56 [00:48<03:17,  4.38s/it]

Loss: 0.0005172462551854551


 21%|██▏       | 12/56 [00:52<03:16,  4.47s/it]

Loss: 0.0005919271497987211


 23%|██▎       | 13/56 [00:56<03:05,  4.32s/it]

Loss: 0.007093268912285566


 25%|██▌       | 14/56 [01:01<03:06,  4.43s/it]

Loss: 0.0005187986535020173


 27%|██▋       | 15/56 [01:06<03:04,  4.49s/it]

Loss: 0.007911707274615765


 29%|██▊       | 16/56 [01:09<02:51,  4.29s/it]

Loss: 0.006548676174134016


 30%|███       | 17/56 [01:14<02:47,  4.30s/it]

Loss: 0.013257094658911228


 32%|███▏      | 18/56 [01:18<02:45,  4.34s/it]

Loss: 0.000569673371501267


 34%|███▍      | 19/56 [01:23<02:43,  4.42s/it]

Loss: 0.0007577237556688488


 36%|███▌      | 20/56 [01:28<02:44,  4.56s/it]

Loss: 0.000846448412630707


 38%|███▊      | 21/56 [01:34<02:56,  5.04s/it]

Loss: 0.0007311733206734061


 39%|███▉      | 22/56 [01:40<02:59,  5.29s/it]

Loss: 0.0008377482881769538


 41%|████      | 23/56 [01:46<03:01,  5.50s/it]

Loss: 0.0007303683669306338


 43%|████▎     | 24/56 [01:51<02:55,  5.47s/it]

Loss: 0.0008884831913746893


 45%|████▍     | 25/56 [01:58<03:00,  5.81s/it]

Loss: 0.0007220946718007326


 46%|████▋     | 26/56 [02:03<02:53,  5.77s/it]

Loss: 0.0006365891313180327


 48%|████▊     | 27/56 [02:10<02:50,  5.87s/it]

Loss: 0.0007056114845909178


 50%|█████     | 28/56 [02:16<02:47,  5.98s/it]

Loss: 0.0005633183754980564


 52%|█████▏    | 29/56 [02:22<02:41,  5.98s/it]

Loss: 0.00048047254676930606


 54%|█████▎    | 30/56 [02:28<02:34,  5.94s/it]

Loss: 0.0005078922840766609


 55%|█████▌    | 31/56 [02:33<02:28,  5.93s/it]

Loss: 0.0004774780827574432


 57%|█████▋    | 32/56 [02:39<02:18,  5.77s/it]

Loss: 0.0004727728955913335


 59%|█████▉    | 33/56 [02:45<02:16,  5.92s/it]

Loss: 0.00040530675323680043


 61%|██████    | 34/56 [02:51<02:12,  6.04s/it]

Loss: 0.01576053537428379


 62%|██████▎   | 35/56 [02:58<02:08,  6.13s/it]

Loss: 0.00042966107139363885


 64%|██████▍   | 36/56 [03:04<02:01,  6.06s/it]

Loss: 0.0004086526168975979


 66%|██████▌   | 37/56 [03:10<01:57,  6.17s/it]

Loss: 0.0003697772917803377


 68%|██████▊   | 38/56 [03:16<01:48,  6.01s/it]

Loss: 0.0005298279575072229


 70%|██████▉   | 39/56 [03:22<01:41,  5.98s/it]

Loss: 0.011495450511574745


 71%|███████▏  | 40/56 [03:27<01:33,  5.83s/it]

Loss: 0.007127890828996897


 73%|███████▎  | 41/56 [03:33<01:29,  5.98s/it]

Loss: 0.00047191540943458676


 75%|███████▌  | 42/56 [03:39<01:22,  5.91s/it]

Loss: 0.000863094290252775


 77%|███████▋  | 43/56 [03:45<01:17,  5.95s/it]

Loss: 0.004289249423891306


 79%|███████▊  | 44/56 [03:51<01:11,  5.97s/it]

Loss: 0.0016353782266378403


 80%|████████  | 45/56 [03:58<01:06,  6.06s/it]

Loss: 0.0012870724312961102


 82%|████████▏ | 46/56 [04:04<01:00,  6.09s/it]

Loss: 0.001036743400618434


 84%|████████▍ | 47/56 [04:10<00:56,  6.27s/it]

Loss: 0.005204011220484972


 86%|████████▌ | 48/56 [04:16<00:49,  6.15s/it]

Loss: 0.0008123539155349135


 88%|████████▊ | 49/56 [04:23<00:43,  6.19s/it]

Loss: 0.0045530591160058975


 89%|████████▉ | 50/56 [04:29<00:37,  6.18s/it]

Loss: 0.0009942019823938608


 91%|█████████ | 51/56 [04:35<00:31,  6.23s/it]

Loss: 0.000421932025346905


 93%|█████████▎| 52/56 [04:41<00:24,  6.15s/it]

Loss: 0.0011981988791376352


 95%|█████████▍| 53/56 [04:47<00:18,  6.08s/it]

Loss: 0.0007334440597333014


 96%|█████████▋| 54/56 [04:53<00:12,  6.09s/it]

Loss: 0.004235835745930672


 98%|█████████▊| 55/56 [04:59<00:06,  6.12s/it]

Loss: 0.0007193561759777367


100%|██████████| 56/56 [05:01<00:00,  5.39s/it]

Loss: 0.000334622134687379





Epoch 4: {'model_precision': 0.0, 'model_recall': 0.0, 'model_f1': 0.0, 'model_number': 17, 'other_precision': 0.9995317836289523, 'other_recall': 1.0, 'other_f1': 0.999765836995, 'other_number': 36291, 'overall_precision': 0.9995317836289523, 'overall_recall': 0.9995317836289523, 'overall_f1': 0.9995317836289523, 'overall_accuracy': 0.9995317836289523}


In [None]:
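import evaluate
import torch

metric = evaluate.load("seqeval")

model.eval()
for batch in eval_dataloader:
    # move every tensor in the batch to the device used for training
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**batch)

    # highest-scoring label id per token
    predictions = outputs.logits.argmax(dim=-1)
    labels = batch["labels"]

    # get_labels (defined earlier in the notebook) converts label ids to label strings
    preds, refs = get_labels(predictions, labels)
    metric.add_batch(predictions=preds, references=refs)

eval_metrics = metric.compute()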
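The loop above relies on a `get_labels` helper that is defined earlier in the notebook and not shown in this section. The sketch below is a hypothetical stand-in (`get_labels_sketch`, which takes `id2label` as an explicit argument). It shows what such a helper commonly does before calling seqeval: map label ids to strings and drop positions whose reference label is `-100`, the ignore index typically assigned to special tokens and non-first subwords. That convention is an assumption here; check it against the notebook's own helper.

```python
import torch


def get_labels_sketch(predictions, references, id2label):
    """Hypothetical stand-in for the notebook's `get_labels` helper.

    Turns (batch, seq_len) tensors of label ids into lists of label strings,
    keeping only positions whose reference label is not -100.
    """
    y_pred = predictions.detach().cpu().tolist()
    y_true = references.detach().cpu().tolist()

    true_predictions, true_labels = [], []
    for pred_row, gold_row in zip(y_pred, y_true):
        # keep (prediction, reference) pairs only where the reference is a real label
        kept = [(p, g) for p, g in zip(pred_row, gold_row) if g != -100]
        true_predictions.append([id2label[p] for p, _ in kept])
        true_labels.append([id2label[g] for _, g in kept])
    return true_predictions, true_labels


# Toy usage with a hypothetical label mapping
toy_id2label = {0: "other", 1: "model", 2: "manufacturer"}
toy_preds = torch.tensor([[0, 1, 0, 2]])
toy_refs = torch.tensor([[-100, 1, -100, 0]])
print(get_labels_sketch(toy_preds, toy_refs, toy_id2label))
# ([['model', 'manufacturer']], [['model', 'other']])
```

seqeval expects lists of string tags per sequence, which is why the ids are converted and the ignored positions removed before `metric.add_batch`.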

In [14]:
eval_metrics

{'manufacturer': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 100},
 'model': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'number': 77},
 'other': {'precision': 0.9946530525934205,
  'recall': 1.0,
  'f1': 0.9973193596752942,
  'number': 32926},
 'overall_precision': 0.9946530525934205,
 'overall_recall': 0.9946530525934205,
 'overall_f1': 0.9946530525934205,
 'overall_accuracy': 0.9946530525934205}
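The per-class numbers above are consistent with the model predicting `other` for every entity: `manufacturer` and `model` get zero precision and recall, while `other` reaches a recall of 1.0. A quick check of the arithmetic, using only the `number` counts reported in `eval_metrics`:

```python
# Entity counts copied from the `number` fields of eval_metrics above
n_other, n_model, n_manufacturer = 32926, 77, 100
n_entities = n_other + n_model + n_manufacturer  # 33103

# If every entity is predicted as "other", only the "other" entities are correct,
# and each entity receives exactly one prediction, so precision and recall coincide
precision_if_all_other = n_other / n_entities  # correct / predicted
recall_if_all_other = n_other / n_entities     # correct / gold
print(round(precision_if_all_other, 5))  # 0.99465
```

This reproduces `overall_precision`, `overall_recall` and `overall_f1`, which suggests the high overall score mostly reflects how dominant `other` nodes are rather than extraction quality for the two target classes.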

In [None]:
# res = {}

# for x in all_refs:
#     for y in x:
#         y = int(y)
#         if y in res:
#             res[y] += 1
#         else:
#             res[y] = 1
# print(res)

In [None]:
# from collections import Counter

# count = {
#     0: 0,
#     1: 0,
#     2: 0,
#     3: 0
# }

# for example in valid_data:
#     example_cnt = Counter(example["node_labels"][0])
#     # print(example_cnt)
#     for key in count.keys():
#         count[key] += example_cnt[key]

# print(count)

## Inference

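Let's try out the model on a single web page for which we have the nodes and xpaths. For simplicity, we reuse the first example of our training set; `valid_data[0]` would instead give a page from a website the model did not see during training.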

In [None]:
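# `data` is a dict keyed by website name, so take an example from the flattened split instead
example = train_data[0]
nodes = example['nodes']
xpaths = example['xpaths']
node_labels = example['node_labels']
print("Nodes:", nodes)
print("Xpaths:", xpaths)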
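The printed `train_data[0]` shows `'nodes': [[...]]`, i.e. each example appears to store one page wrapped in an outer list (a batch of one), which is why later cells index `nodes[0]`. The hypothetical `check_page` helper below sketches a sanity check under that assumption: one page per example, and exactly one xpath and one label per node.

```python
def check_page(example):
    """Hypothetical sanity check for one example, assuming the layout
    {"nodes": [[...]], "xpaths": [[...]], "node_labels": [[...]]}."""
    nodes, xpaths, node_labels = example["nodes"], example["xpaths"], example["node_labels"]
    # outer lists: a batch containing a single page
    assert len(nodes) == len(xpaths) == len(node_labels) == 1, "expected one page per example"
    # inner lists: one xpath and one label id per node
    assert len(nodes[0]) == len(xpaths[0]) == len(node_labels[0]), "nodes, xpaths and labels differ in length"
    return len(nodes[0])


# Toy example with made-up xpaths and label ids
toy_page = {
    "nodes": [["Canon PowerShot D10", "Home"]],
    "xpaths": [["/html/body/h1", "/html/body/div/a"]],
    "node_labels": [[1, 0]],
}
print(check_page(toy_page))  # 2
```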

We'll prepare the example for the model using the processor. Note that we're passing `return_offsets_mapping=True`, as the offsets allow us to determine which tokens are at the start of a given node and which aren't.

In [None]:
# prepare the page for the model
# node_labels aren't needed for inference; we pass them only to compare predictions with the ground truth
encoding = processor(nodes=nodes, xpaths=xpaths, node_labels=node_labels,
                     truncation=True, max_length=max_length,
                     return_offsets_mapping=True, return_tensors="pt").to(device)
for k, v in encoding.items():
    print(k, v.shape)
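The `offset[0] == 0` check used later relies on offsets being character spans measured within each node's own text, so the first token of every node starts at character 0. The sketch below illustrates this with a hypothetical helper, `char_offsets`, applied to a subword split we chose by hand; it is not the output of MarkupLM's actual tokenizer.

```python
def char_offsets(pieces, text):
    """Hypothetical helper: (start, end) character spans of `pieces` inside `text`,
    mimicking what a fast tokenizer reports via return_offsets_mapping=True."""
    offsets, cursor = [], 0
    for piece in pieces:
        start = text.index(piece, cursor)  # search from the end of the previous piece
        end = start + len(piece)
        offsets.append((start, end))
        cursor = end
    return offsets


node_text = "Canon PowerShot D10"
toy_pieces = ["Canon", "Power", "Shot", "D", "10"]  # hand-made split, not real tokenizer output
toy_offsets = char_offsets(toy_pieces, node_text)
print(toy_offsets)  # [(0, 5), (6, 11), (11, 15), (16, 17), (17, 19)]

# only the first piece of the node starts at character 0
print([start == 0 for start, _ in toy_offsets])  # [True, False, False, False, False]
```

Special tokens usually also get a `(0, 0)` span, so the offset test alone is not enough to tell them apart from real first tokens.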

Let's perform a forward pass:

In [None]:
# the model doesn't accept offset_mapping, and we don't need a loss, so set both aside
offset_mapping = encoding.pop("offset_mapping")
labels = encoding.pop("labels")

# forward pass
with torch.no_grad():
    outputs = model(**encoding)

The model outputs logits of shape (batch_size, seq_len, num_labels). We simply take the highest logit (score) per token as the prediction:

In [None]:
predictions = outputs.logits.argmax(-1)
print(predictions)
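As a shape check, here is the same reduction on random logits with hypothetical sizes (one page, six tokens, three labels). Taking `argmax` over the last dimension leaves one label id per token. Applying `softmax` first does not change which label wins, but yields per-token probabilities that can serve as rough confidence scores; that extra step is our addition, not part of the original notebook.

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes: 1 page, 6 tokens, 3 labels
batch_size, seq_len, num_labels = 1, 6, 3
toy_logits = torch.randn(batch_size, seq_len, num_labels)

# argmax over the label dimension gives one label id per token
toy_predictions = toy_logits.argmax(-1)
print(toy_predictions.shape)  # torch.Size([1, 6])

# softmax is monotonic, so it keeps the argmax but turns scores into probabilities
toy_probs = toy_logits.softmax(-1)
toy_confidence = toy_probs.max(-1).values  # probability of the predicted label per token
```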

The model makes predictions at the token level; however, we're only interested in the predicted label for the first token of each node.

We can select those tokens using the word_ids (which map each token to its node index and are `None` for special tokens) and the offset_mapping (whose start offset is 0 for the first token of a node).

In [None]:
results = {"Node": [], "Predicted": [], "Ground truth": []}

# word_ids(0) maps each token of the first (only) page to the index of its node
for pred_id, word_id, offset, label_id in zip(predictions[0].tolist(), encoding.word_ids(0),
                                               offset_mapping[0].tolist(), labels[0].tolist()):
    # skip special tokens (word_id is None) and tokens that don't start a node
    if word_id is not None and offset[0] == 0:
        results["Node"].append(nodes[0][word_id])
        results["Predicted"].append(id2label[pred_id])
        results["Ground truth"].append(id2label[label_id])
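The same selection rule can be written as a small standalone function, shown here on toy token data (made-up word ids, offsets, and a hypothetical `id2label`) so it runs without the model. Keying the output by node index makes it easy to spot nodes that were cut off by `truncation=True`: they simply never appear.

```python
def first_token_predictions(pred_ids, word_ids, offsets, id2label):
    """Keep one prediction per node: the one made on the node's first token.

    word_id is None for special tokens; a start offset of 0 marks a node's first token.
    """
    node_preds = {}
    for pred_id, word_id, (start, _end) in zip(pred_ids, word_ids, offsets):
        if word_id is not None and start == 0:
            node_preds[word_id] = id2label[pred_id]
    return node_preds


# Toy data: <s>, two tokens for node 0, one token for node 1, </s>
toy_word_ids = [None, 0, 0, 1, None]
toy_token_offsets = [(0, 0), (0, 5), (6, 11), (0, 4), (0, 0)]
toy_pred_ids = [0, 2, 0, 1, 0]
toy_labels = {0: "other", 1: "model", 2: "manufacturer"}  # hypothetical mapping

print(first_token_predictions(toy_pred_ids, toy_word_ids, toy_token_offsets, toy_labels))
# {0: 'manufacturer', 1: 'model'}
```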

Let's pretty-print the results as a pandas DataFrame:

In [None]:
import pandas as pd

pd.DataFrame.from_dict(results).head()
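Beyond eyeballing the first rows, the same table can be summarized with node-level accuracy and a confusion table via `pd.crosstab`. The sketch uses a toy stand-in (`toy_results`, with made-up rows) so it doesn't overwrite the notebook's `results`; with the real data you would build the DataFrame from `results` instead.

```python
import pandas as pd

# Toy stand-in for the `results` dict built above (made-up rows)
toy_results = {
    "Node": ["Canon", "PowerShot D10", "Home", "Help Center"],
    "Predicted": ["other", "other", "other", "other"],
    "Ground truth": ["manufacturer", "model", "other", "other"],
}
toy_df = pd.DataFrame.from_dict(toy_results)

# fraction of nodes whose predicted label matches the ground truth
node_accuracy = (toy_df["Predicted"] == toy_df["Ground truth"]).mean()
print(node_accuracy)  # 0.5

# rows: ground truth, columns: prediction
print(pd.crosstab(toy_df["Ground truth"], toy_df["Predicted"]))
```

A confusion table makes a majority-class collapse visible at a glance: every row has its count in the `other` column.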