In [1]:
%load_ext autoreload
%autoreload 2

In [3]:
import transformers
from transformers import LayoutLMForTokenClassification,\
    LayoutLMTokenizer, AdamW, LayoutLMv2Processor, LayoutLMv2ForTokenClassification
# from tensordict import TensorDict
from datasets import load_dataset
import torch
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import ImageDraw, Image
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
import logging
# from torchvision.transforms import PILToTensor

import os 
if 'notebooks' in os.getcwd():
    os.chdir("..")
    
from src.preprocessing.make_dataset import ImageLayoutDataset

In [4]:
logging.basicConfig(filename='logs/train.log', encoding='utf-8', level=logging.DEBUG)

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
device

device(type='cuda')

In [7]:
dataset = load_dataset("katanaml/cord", )

# dataset = load_dataset("darentang/sroie")

Repo card metadata block was not found. Setting CardData to empty.


## Creating PyTorch Datasets, DataLoader

In [20]:
processor = LayoutLMv2Processor.from_pretrained("microsoft/layoutlmv2-base-uncased", revision="no_ocr")


In [21]:
data = []
for example in tqdm(dataset['train']):
    words = example['words']
    boxes = example['bboxes']
    image = Image.open(example['image_path']).convert("RGB")
    word_labels = example['ner_tags']

    # try:
    encoded_inputs = processor(
        image, 
        words, 
        boxes=boxes, 
        word_labels=word_labels, 
        padding="max_length", 
        truncation=True, 
        return_tensors="pt"
    )

    
    data.append(encoded_inputs)
    # except:
    #     pass

    # assert encoded_inputs.input_ids.shape == torch.Size([512])
    # assert encoded_inputs.attention_mask.shape == torch.Size([512])
    # assert encoded_inputs.token_type_ids.shape == torch.Size([512])
    # assert encoded_inputs.bbox.shape == torch.Size([512, 4])
    # assert encoded_inputs.image.shape == torch.Size([3, 224, 224])
    # assert encoded_inputs.labels.shape == torch.Size([512]) 

  0%|          | 0/800 [00:00<?, ?it/s]

100%|██████████| 800/800 [00:38<00:00, 20.53it/s]


In [9]:
df = ImageLayoutDataset(dataset['train'], encode=True)

100%|██████████| 800/800 [00:37<00:00, 21.05it/s]


In [10]:
dataloader = DataLoader(
    df,
    shuffle=True,
    batch_size= 2
)

In [11]:
unique_rows = []
for row in dataset['train']['ner_tags']:
    unique_rows.append(np.unique(row))

In [12]:
n_labels = np.unique(np.concatenate(unique_rows)).shape[0]

In [13]:
n_labels

23

## Importing model

In [14]:

model = LayoutLMForTokenClassification.from_pretrained(
    'microsoft/layoutlm-base-uncased',
    num_labels=n_labels
)
model.to(device)

Some weights of LayoutLMForTokenClassification were not initialized from the model checkpoint at microsoft/layoutlm-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


LayoutLMForTokenClassification(
  (layoutlm): LayoutLMModel(
    (embeddings): LayoutLMEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (x_position_embeddings): Embedding(1024, 768)
      (y_position_embeddings): Embedding(1024, 768)
      (h_position_embeddings): Embedding(1024, 768)
      (w_position_embeddings): Embedding(1024, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): LayoutLMEncoder(
      (layer): ModuleList(
        (0-11): 12 x LayoutLMLayer(
          (attention): LayoutLMAttention(
            (self): LayoutLMSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
 

### Number of treinable / non-treinable parameters

In [15]:
print(f'''
    Model Info
    -----------------
    
    Treinable params: {sum(p.numel() for p in model.parameters() if p.requires_grad)}
    Non Treinable params: {sum(p.numel() for p in model.parameters() if not p.requires_grad)}

''')


    Model Info
    -----------------
    
    Treinable params: 112645655
    Non Treinable params: 0




In [54]:
optimizer = AdamW(model.parameters(), lr=5e-5)

global_step = 0
num_train_epochs = 4
t_total = len(dataloader) * num_train_epochs # total number of training steps 

#put the model in training mode
model.train()
for epoch in tqdm(range(num_train_epochs)):
   logging.info(f"Epoch: {epoch}")
   running_loss = 0.0
   accuracy = []
   for X in (dataloader):
      input_ids = X["input_ids"]\
         .to(device)\
         .squeeze()
      
      bbox = X["bbox"]\
         .to(device)\
         .squeeze()

      attention_mask = X["attention_mask"]\
         .to(device)\
         .squeeze()
      token_type_ids = X["token_type_ids"]\
         .to(device)\
         .squeeze()
      
      labels = X["labels"]\
         .to(device)\
         .squeeze()

      image = X["image"]\
         .to(device)\
         .squeeze()

      # forward pass
      outputs = model(
         input_ids=input_ids, 
         bbox=bbox, 
         attention_mask=attention_mask, 
         token_type_ids=token_type_ids,
         labels=labels
      )

      
      loss = outputs.loss

      running_loss += loss.item()
      predictions = outputs.logits.argmax(-1)
      valid_outputs_mask = labels != -100

      correct = (predictions == labels)\
         [valid_outputs_mask]\
         .float()\
         .sum()

      # backward pass to get the gradients 
      loss.backward()

      # update
      optimizer.step()
      optimizer.zero_grad()
      global_step += 1
      
      acc = correct / (labels[valid_outputs_mask]).shape[0]

      accuracy.append(acc)

   logging.info(f"Loss: {running_loss / len(dataloader.dataset)}")
   logging.info(f"Training accuracy: {torch.tensor(accuracy).mean().item()}", )

100%|██████████| 4/4 [01:38<00:00, 24.64s/it]


In [53]:
correct / (labels[valid_outputs_mask]).shape[0]

tensor(434.6750, device='cuda:0')

In [52]:
accuracy

[tensor(0.9455, device='cuda:0'),
 tensor(1.1360, device='cuda:0'),
 tensor(3.8085, device='cuda:0'),
 tensor(7.5185, device='cuda:0'),
 tensor(7.0645, device='cuda:0'),
 tensor(7.3529, device='cuda:0'),
 tensor(7.0976, device='cuda:0'),
 tensor(8.5263, device='cuda:0'),
 tensor(11.7333, device='cuda:0'),
 tensor(6.0435, device='cuda:0'),
 tensor(5.6517, device='cuda:0'),
 tensor(15.1714, device='cuda:0'),
 tensor(13.3023, device='cuda:0'),
 tensor(16.7222, device='cuda:0'),
 tensor(14.0870, device='cuda:0'),
 tensor(20.4848, device='cuda:0'),
 tensor(17.7500, device='cuda:0'),
 tensor(14.2075, device='cuda:0'),
 tensor(16.4792, device='cuda:0'),
 tensor(22.2703, device='cuda:0'),
 tensor(42.1500, device='cuda:0'),
 tensor(10.5176, device='cuda:0'),
 tensor(23.8462, device='cuda:0'),
 tensor(28.8182, device='cuda:0'),
 tensor(30.5938, device='cuda:0'),
 tensor(31.5312, device='cuda:0'),
 tensor(21.5510, device='cuda:0'),
 tensor(16.9394, device='cuda:0'),
 tensor(30.3947, device='cuda:

In [51]:
torch.tensor(accuracy).mean()

tensor(204.2037)

In [41]:
(predictions[labels != -100] == labels[labels!= -100]).sum()

tensor(5, device='cuda:0')

In [44]:
labels[labels!= -100].shape

torch.Size([55])

In [36]:
outputs.logits.argmax(-1)

tensor([[ 3,  5,  9,  ...,  3,  3,  7],
        [14, 22, 22,  ..., 14, 14,  3]], device='cuda:0')

In [25]:
outputs = model(
        input_ids=input_ids.squeeze(), 
        bbox=bbox, 
        attention_mask=attention_mask.squeeze(), 
        token_type_ids=token_type_ids.squeeze(),
        # image= image,
        labels=labels
      )

In [77]:
input_ids.shape

torch.Size([2, 1, 512])

In [78]:
bbox.shape

torch.Size([2, 512, 4])

In [79]:
attention_mask.shape

torch.Size([2, 1, 512])

In [80]:
token_type_ids.shape

torch.Size([2, 1, 512])

In [81]:

image.shape

torch.Size([2, 3, 224, 224])

In [86]:
torch.unique(labels)


tensor([-100,    1,    3,    5,    9,   10,   13,   14,   15,   16,   22])