<a href="https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/LiLT/Fine_tune_LiltForTokenClassification_on_FUNSD_(nielsr_funsd).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Set-up environment

In [1]:
!pip install -q git+https://github.com/huggingface/transformers.git

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
[K     |████████████████████████████████| 163 kB 15.8 MB/s 
[K     |████████████████████████████████| 7.6 MB 63.2 MB/s 
[?25h  Building wheel for transformers (PEP 517) ... [?25l[?25hdone


In [2]:
!pip install -q datasets

[K     |████████████████████████████████| 441 kB 14.3 MB/s 
[K     |████████████████████████████████| 115 kB 69.1 MB/s 
[K     |████████████████████████████████| 212 kB 71.3 MB/s 
[K     |████████████████████████████████| 127 kB 63.2 MB/s 
[K     |████████████████████████████████| 115 kB 69.7 MB/s 
[?25h

## Load dataset

In [3]:
from datasets import load_dataset

dataset = load_dataset("nielsr/funsd")

Downloading builder script:   0%|          | 0.00/4.54k [00:00<?, ?B/s]

Downloading and preparing dataset funsd/funsd to /root/.cache/huggingface/datasets/nielsr___funsd/funsd/1.0.0/8b0472b536a2dcb975d59a4fb9d6fea4e6a1abe260b7fed6f75301e168cbe595...


Downloading data:   0%|          | 0.00/16.8M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Dataset funsd downloaded and prepared to /root/.cache/huggingface/datasets/nielsr___funsd/funsd/1.0.0/8b0472b536a2dcb975d59a4fb9d6fea4e6a1abe260b7fed6f75301e168cbe595. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'words', 'bboxes', 'ner_tags', 'image_path'],
        num_rows: 149
    })
    test: Dataset({
        features: ['id', 'words', 'bboxes', 'ner_tags', 'image_path'],
        num_rows: 50
    })
})

In [5]:
dataset["train"].features

{'id': Value(dtype='string', id=None),
 'words': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None),
 'bboxes': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'ner_tags': Sequence(feature=ClassLabel(names=['O', 'B-HEADER', 'I-HEADER', 'B-QUESTION', 'I-QUESTION', 'B-ANSWER', 'I-ANSWER'], id=None), length=-1, id=None),
 'image_path': Value(dtype='string', id=None)}

In [6]:
labels = dataset["train"].features['ner_tags'].feature.names
id2label = {id:label for id, label in enumerate(labels)}
label2id = {label:id for id, label in enumerate(labels)}
print(id2label)

{0: 'O', 1: 'B-HEADER', 2: 'I-HEADER', 3: 'B-QUESTION', 4: 'I-QUESTION', 5: 'B-ANSWER', 6: 'I-ANSWER'}


In [7]:
example = dataset["train"][0]
print(example["words"])
print(example["bboxes"])
print(example["ner_tags"])

['R&D', ':', 'Suggestion:', 'Date:', 'Licensee', 'Yes', 'No', '597005708', 'R&D', 'QUALITY', 'IMPROVEMENT', 'SUGGESTION/', 'SOLUTION', 'FORM', 'Name', '/', 'Phone', 'Ext.', ':', 'M.', 'Hamann', 'P.', 'Harper,', 'P.', 'Martinez', '9/', '3/', '92', 'R&D', 'Group:', 'J.', 'S.', 'Wigand', 'Supervisor', '/', 'Manager', 'Discontinue', 'coal', 'retention', 'analyses', 'on', 'licensee', 'submitted', 'product', 'samples', '(Note', ':', 'Coal', 'Retention', 'testing', 'is', 'not', 'performed', 'by', 'most', 'licensees.', 'Other', 'B&W', 'physical', 'measurements', 'as', 'ends', 'stability', 'and', 'inspection', 'for', 'soft', 'spots', 'in', 'ciparettes', 'are', 'thought', 'to', 'be', 'sufficient', 'measures', 'to', 'assure', 'cigarette', 'physical', 'integrity.', 'The', 'proposed', 'action', 'will', 'increase', 'laboratory', 'productivity', '.', ')', 'Suggested', 'Solutions', '(s)', ':', 'Delete', 'coal', 'retention', 'from', 'the', 'list', 'of', 'standard', 'analyses', 'performed', 'on', 'licen

## Transform dataset

In [8]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlmv3-base")

Downloading:   0%|          | 0.00/1.14k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

In [9]:
def prepare_examples(batch):
  encoding = tokenizer(batch["words"],
                        boxes=batch["bboxes"],
                        word_labels=batch["ner_tags"],
                        padding="max_length",
                        max_length=128,
                        truncation=True,
                        return_tensors="pt")
  
  return encoding

dataset.set_transform(prepare_examples)

In [10]:
example = dataset["train"][0]
print(example.keys())

dict_keys(['input_ids', 'attention_mask', 'bbox', 'labels'])


In [11]:
tokenizer.decode(example["input_ids"])

'<s> R&D : Suggestion: Date: Licensee Yes No 597005708 R&D QUALITY IMPROVEMENT SUGGESTION/ SOLUTION FORM Name / Phone Ext. : M. Hamann P. Harper, P. Martinez 9/ 3/ 92 R&D Group: J. S. Wigand Supervisor / Manager Discontinue coal retention analyses on licensee submitted product samples (Note : Coal Retention testing is not performed by most licensees. Other B&W physical measurements as ends stability and inspection for soft spots in ciparettes are thought to be sufficient measures to assure cigarette</s>'

In [12]:
for id, box, label in zip(example["input_ids"].tolist(),
                          example["bbox"].tolist(),
                          example["labels"].tolist()):
  if label != -100:
    print(tokenizer.decode([id]), box, id2label[label])
  else:
    print(tokenizer.decode([id]), box, label)

<s> [0, 0, 0, 0] -100
 R [383, 91, 493, 175] O
& [383, 91, 493, 175] -100
D [383, 91, 493, 175] -100
 : [287, 316, 295, 327] B-QUESTION
 Suggest [124, 355, 221, 370] B-QUESTION
ion [124, 355, 221, 370] -100
: [124, 355, 221, 370] -100
 Date [632, 268, 679, 282] B-QUESTION
: [632, 268, 679, 282] -100
 License [670, 309, 748, 323] B-ANSWER
e [670, 309, 748, 323] -100
 Yes [604, 605, 633, 619] B-QUESTION
 No [715, 603, 738, 617] B-QUESTION
 5 [688, 904, 841, 926] O
97 [688, 904, 841, 926] -100
005 [688, 904, 841, 926] -100
708 [688, 904, 841, 926] -100
 R [337, 203, 366, 214] B-HEADER
& [337, 203, 366, 214] -100
D [337, 203, 366, 214] -100
 QU [374, 203, 438, 216] I-HEADER
AL [374, 203, 438, 216] -100
ITY [374, 203, 438, 216] -100
 IM [447, 201, 548, 211] I-HEADER
PROV [447, 201, 548, 211] -100
EMENT [447, 201, 548, 211] -100
 S [335, 215, 425, 229] I-HEADER
UG [335, 215, 425, 229] -100
G [335, 215, 425, 229] -100
EST [335, 215, 425, 229] -100
ION [335, 215, 425, 229] -100
/ [335, 215, 42

## Create PyTorch Dataloaders

In [13]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(dataset["train"], batch_size=2, shuffle=True)
test_dataloader = DataLoader(dataset["test"], batch_size=2, shuffle=True)

In [14]:
batch = next(iter(train_dataloader))
for k,v in batch.items():
  print(k,v.shape)

input_ids torch.Size([2, 128])
attention_mask torch.Size([2, 128])
bbox torch.Size([2, 128, 4])
labels torch.Size([2, 128])


## Load model

In [15]:
from transformers import AutoModelForTokenClassification

model = AutoModelForTokenClassification.from_pretrained("nielsr/lilt-roberta-en-base", id2label=id2label, label2id=label2id)

Downloading:   0%|          | 0.00/697 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/523M [00:00<?, ?B/s]

Some weights of LiltForTokenClassification were not initialized from the model checkpoint at nielsr/lilt-roberta-en-base and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Train!

In [16]:
!pip install -q evaluate seqeval

[?25l[K     |████▌                           | 10 kB 28.8 MB/s eta 0:00:01[K     |█████████                       | 20 kB 15.0 MB/s eta 0:00:01[K     |█████████████▌                  | 30 kB 20.2 MB/s eta 0:00:01[K     |██████████████████              | 40 kB 16.3 MB/s eta 0:00:01[K     |██████████████████████▌         | 51 kB 13.9 MB/s eta 0:00:01[K     |███████████████████████████     | 61 kB 16.2 MB/s eta 0:00:01[K     |███████████████████████████████▌| 71 kB 14.4 MB/s eta 0:00:01[K     |████████████████████████████████| 72 kB 1.6 MB/s 
[?25h[?25l[K     |███████▌                        | 10 kB 39.6 MB/s eta 0:00:01[K     |███████████████                 | 20 kB 47.2 MB/s eta 0:00:01[K     |██████████████████████▌         | 30 kB 54.1 MB/s eta 0:00:01[K     |██████████████████████████████  | 40 kB 58.9 MB/s eta 0:00:01[K     |████████████████████████████████| 43 kB 2.5 MB/s 
[?25h  Building wheel for seqeval (setup.py) ... [?25l[?25hdone


In [17]:
import evaluate
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Metric
metric = evaluate.load("seqeval")

def get_labels(predictions, references):
    # Transform predictions and references tensors to numpy arrays
    if device.type == "cpu":
        y_pred = predictions.detach().clone().numpy()
        y_true = references.detach().clone().numpy()
    else:
        y_pred = predictions.detach().cpu().clone().numpy()
        y_true = references.detach().cpu().clone().numpy()

    # Remove ignored index (special tokens)
    true_predictions = [
        [labels[p] for (p, l) in zip(pred, gold_label) if l != -100]
        for pred, gold_label in zip(y_pred, y_true)
    ]
    true_labels = [
        [labels[l] for (p, l) in zip(pred, gold_label) if l != -100]
        for pred, gold_label in zip(y_pred, y_true)
    ]
    return true_predictions, true_labels

Downloading builder script:   0%|          | 0.00/6.34k [00:00<?, ?B/s]

In [19]:
from torch.optim import AdamW
from tqdm.auto import tqdm

optimizer = AdamW(model.parameters(), lr=5e-5)

model.to(device)

for epoch in range(50):
  print("Epoch:", epoch+1)
  for idx, batch in enumerate(tqdm(train_dataloader)):
      # move batch to device
      batch = {k:v.to(device) for k,v in batch.items()}
      outputs = model(**batch)

      predictions = outputs.logits.argmax(-1)
      true_predictions, true_labels = get_labels(predictions, batch["labels"])
      metric.add_batch(references=true_labels, predictions=true_predictions)

      loss = outputs.loss

      if idx % 100 == 0:
        print("Loss:", loss.item())
        results = metric.compute()
        print("Overall f1:", results["overall_f1"])
        print("Overall precision:", results["overall_f1"])
        print("Overall recall:", results["overall_recall"])

      loss.backward()
      optimizer.step()
      optimizer.zero_grad()

Epoch: 1


  0%|          | 0/75 [00:00<?, ?it/s]

Loss: 0.00019168781000189483
Overall f1: 0.9969788519637462
Overall precision: 0.9969788519637462
Overall recall: 1.0
Epoch: 2


  0%|          | 0/75 [00:00<?, ?it/s]

Loss: 0.0012669997522607446
Overall f1: 0.9875961299925576
Overall precision: 0.9875961299925576
Overall recall: 0.9883316782522343


KeyboardInterrupt: ignored

## Evaluate

In [20]:
from tqdm.auto import tqdm

eval_metric = evaluate.load("seqeval")

for idx, batch in enumerate(tqdm(test_dataloader)):
    # move batch to device
    batch = {k:v.to(device) for k,v in batch.items()}
    with torch.no_grad():
      outputs = model(**batch)

    predictions = outputs.logits.argmax(-1)
    true_predictions, true_labels = get_labels(predictions, batch["labels"])
    eval_metric.add_batch(references=true_labels, predictions=true_predictions)

  0%|          | 0/25 [00:00<?, ?it/s]

In [21]:
results = eval_metric.compute()
results["overall_f1"]

0.7721859393008068

In [22]:
results["overall_accuracy"]

0.7934749153585718