<a href="https://colab.research.google.com/github/uakarsh/docformer/blob/master/examples/docformer_pl/3_Pre_training_DocFormer_Task_MLM_Task.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
## Download the sample archive of the IDL dataset (see https://github.com/furkanbiten/idl_data).
# The IDL dataset was also used in the pre-training of LaTr; the download might take some time.
# NOTE(review): the unzip/rm paths assume the working directory is /content (the Colab default) - confirm.

!wget http://datasets.cvc.uab.es/UCSF_IDL/Samples/ocr_imgs_sample.zip
!unzip /content/ocr_imgs_sample.zip
!rm /content/ocr_imgs_sample.zip

In [None]:
## Installing the dependencies (might take some time)

%%capture
!pip install pytesseract
!sudo apt install tesseract-ocr
!pip install transformers
!pip install pytorch-lightning
!pip install einops
!pip install accelerate
!pip install tqdm
!pip install 'Pillow==7.1.2'

In [None]:
## Cloning the repository

%%capture
!git clone https://github.com/uakarsh/docformer.git

In [None]:
## Importing the libraries

import warnings
warnings.simplefilter("ignore", UserWarning)

import os
import pickle
import pytesseract
import numpy as np
import pandas as pd
from PIL import Image,ImageDraw
import json

import torch
from torchvision.transforms import ToTensor
import torch.nn as nn
from torch.utils.data import Dataset,DataLoader

import math
import torch.nn.functional as F
import torchvision.models as models
from einops import rearrange
from torch import Tensor

## Adding the path of docformer to system path
import sys
sys.path.append('/content/docformer/src/docformer/')

## Importing the functions from the DocFormer Repo
from dataset import create_features, resize_align_bbox
from modeling import DocFormerEncoder,ResNetFeatureExtractor,DocFormerEmbeddings,LanguageFeatureExtractor
from transformers import BertTokenizerFast
from utils import load_json_file, get_specific_file

## 1. Dataset loading

In [None]:
## Making the json entries

json_path = '/content/sample/OCR'
pdf_path = '/content/sample/pdfs'
resize_shape = (500, 384)

target_size = resize_shape
width, height = target_size

## Collect every OCR json file found one level below json_path.
## os.listdir returns names in arbitrary order, so both levels are sorted: this makes the
## seeded train/validation split in the next cell reproducible across machines and runs.
json_entries = []
for folder_name in sorted(os.listdir(json_path)):
  base_path = os.path.join(json_path, folder_name)
  for file_name in sorted(os.listdir(base_path)):
    json_entries.append(os.path.join(base_path, file_name))

In [None]:
## Splitting the dataset

from sklearn.model_selection import train_test_split as tts

## 80/20 train/validation split of the json paths; the fixed random_state keeps the shuffle repeatable
train_json_entries, val_json_entries = tts(
    json_entries,
    test_size = 0.2,
    random_state = 122,
    shuffle = True,
)

In [None]:
def get_words_and_coordinates(tif_path, sample_entry):
    """Extract the first-page words and their bounding boxes from one OCR entry.

    Args:
        tif_path: path of the page image; it is opened only to read its size.
        sample_entry: loaded OCR json; `sample_entry[1]['Blocks']` is iterated and
            every block with `BlockType == 'WORD'` and `Page == 1` is kept.

    Returns:
        (words, coordinates): the lower-cased `Text` of each kept block and, in the
        same order, the box returned by `resize_align_bbox` for it.
    """
    ## Only the size is needed: read it from the image header and close the file
    ## right away instead of decoding the whole image and leaving the handle open.
    with Image.open(tif_path) as page_image:
        original_image_size = page_image.size

    words = []
    coordinates = []

    for block in sample_entry[1]['Blocks']:
        if block['BlockType'] != 'WORD' or block['Page'] != 1:
            continue

        words.append(block['Text'].lower())

        ## The OCR box is stored as Left/Top/Width/Height; turn it into (xmin, ymin, xmax, ymax)
        curr_box = block['Geometry']['BoundingBox']
        xmin, ymin = curr_box['Left'], curr_box['Top']
        xmax, ymax = curr_box['Width'] + xmin, curr_box['Height'] + ymin

        ## presumably the OCR values are fractions of the page, rescaled here from a
        ## 1 x 1 page to the real image size - TODO confirm against resize_align_bbox
        curr_bbox = resize_align_bbox((xmin, ymin, xmax, ymax), 1, 1, *original_image_size)
        coordinates.append(curr_bbox)

    return words, coordinates

In [None]:
class DocumentDataset(Dataset):
  """Map-style dataset that turns one OCR json entry into DocFormer input features.

  Args:
    json_entries: list of paths to the OCR json files that belong to this split.
    pdf_path: directory under which the per-document image folders are looked up.
    target_size: forwarded to `create_features` as `target_size`.
    tokenizer: tokenizer forwarded to `create_features`.
    max_len: forwarded to `create_features` as `max_seq_length`.
    use_mlm: forwarded to `create_features` as `apply_mask_for_mlm`.
  """

  def __init__(self, json_entries, pdf_path, target_size, tokenizer, max_len = 512, use_mlm = False):
    self.json_entries = json_entries
    self.pdf_path = pdf_path
    self.target_size = target_size
    self.tokenizer = tokenizer
    self.max_len = max_len
    self.use_mlm = use_mlm

  def __len__(self):
    return len(self.json_entries)

  def __getitem__(self, idx):
    """Return the feature dict produced by `create_features` for entry `idx` of this split."""

    ## Everything is read from the instance, never from notebook globals: indexing the
    ## global list would make the train and validation datasets serve the same (unsplit) files.
    sample_entry = load_json_file(self.json_entries[idx])

    ## The last path component of sample_entry[0] names the document folder under pdf_path
    sample_tif_file = os.path.join(self.pdf_path, sample_entry[0].split('/')[-1])

    ## Loading the tif path
    tif_path = get_specific_file(sample_tif_file)

    words, coordinates = get_words_and_coordinates(tif_path, sample_entry)

    ## OCR is skipped (use_ocr = False) because words and boxes come from the json entry
    final_encoding = create_features(
            tif_path,
            self.tokenizer,
            add_batch_dim=False,
            target_size=self.target_size,
            max_seq_length=self.max_len,
            path_to_save=None,
            save_to_disk=False,
            apply_mask_for_mlm=self.use_mlm,
            extras_for_debugging=False,
            use_ocr = False,
            bounding_box = coordinates,
            words = words
    )

    return final_encoding

In [None]:
## Defining the tokenizer: the pretrained "bert-base-uncased" fast tokenizer, passed to both the train and validation datasets below
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/455k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [None]:
## Both splits share every setting; only the list of json files differs
shared_dataset_kwargs = dict(
    pdf_path = pdf_path,
    target_size = target_size,
    tokenizer = tokenizer,
    max_len = 512,
    use_mlm = True,
)

train_dataset = DocumentDataset(json_entries = train_json_entries, **shared_dataset_kwargs)
val_dataset = DocumentDataset(json_entries = val_json_entries, **shared_dataset_kwargs)

In [None]:
def collate_fn(data_bunch):
  '''
  Collate a list of per-sample dictionaries into one batch dictionary.

  data_bunch: list of dictionaries mapping a feature name to a tensor

  Returns a dictionary with the same feature names, where the tensors of all
  samples are stacked along a new leading (batch) dimension.
  '''

  ## Group the tensors of every sample under their feature name
  grouped = {}
  for sample in data_bunch:
    for name, tensor in sample.items():
      grouped.setdefault(name, []).append(tensor)

  ## One stacked tensor per feature name
  return {name: torch.stack(tensors, dim = 0) for name, tensors in grouped.items()}

In [None]:
# train_dl = DataLoader(train_dataset, batch_size = 4, collate_fn = collate_fn)
# val_dl = DataLoader(val_dataset, batch_size = 4, collate_fn = collate_fn)

## Defining the DataModule


In [None]:
import pytorch_lightning as pl

class DataModule(pl.LightningDataModule):
  """LightningDataModule serving the train and validation datasets through `collate_fn`.

  Args:
    train_dataset: dataset used by `train_dataloader` (shuffled).
    val_dataset: dataset used by `val_dataloader` (not shuffled).
    batch_size: batch size of both dataloaders.
  """

  def __init__(self, train_dataset, val_dataset,  batch_size = 2):
    super().__init__()
    self.batch_size = batch_size
    self.train_dataset = train_dataset
    self.val_dataset = val_dataset

  def _build_loader(self, dataset, shuffle):
    ## Single place where the DataLoader settings shared by both splits live
    return DataLoader(
        dataset,
        batch_size = self.batch_size,
        collate_fn = collate_fn,
        shuffle = shuffle,
    )

  def train_dataloader(self):
    return self._build_loader(self.train_dataset, shuffle = True)

  def val_dataloader(self):
    return self._build_loader(self.val_dataset, shuffle = False)

In [None]:
## Module-level DataModule with the default batch size of 2; note that main() further down builds its own instance
datamodule = DataModule(train_dataset, val_dataset)

## Modeling part

In [None]:
## Setting some hyperparameters

## NOTE(review): `device` is only referenced by the commented-out scratch cells further down;
## the Trainer built in main() selects the GPU itself through its `gpus` argument.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Model configuration. The whole dict is handed to DocFormerEmbeddings, DocFormerEncoder and DocFormerForMLM.
## Keys without a comment are not read in this notebook; presumably they are consumed inside the
## imported DocFormer modules - confirm against the repository.
config = {
  "coordinate_size": 96,              ## (768/8), 8 for each of the 8 coordinates of x, y
  "hidden_dropout_prob": 0.1,         ## rate of the nn.Dropout created in DocFormerForMLM
  "hidden_size": 768,                 ## input width of the MLM output layer
  "image_feature_pool_shape": [7, 7, 256],
  "intermediate_ff_size_factor": 4,
  "max_2d_position_embeddings": 1024,
  "max_position_embeddings": 512,
  "max_relative_positions": 8,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "shape_size": 96,
  "vocab_size": 30522,                ## output width of the MLM layer and class count in the loss; presumably the bert-base-uncased vocabulary size - TODO confirm
  "layer_norm_eps": 1e-12,
}

In [None]:
class DocFormerForMLM(nn.Module):
    """DocFormer encoder followed by a linear layer that scores every vocabulary entry.

    Args:
        config: dict of hyper-parameters; `hidden_dropout_prob`, `hidden_size` and
            `vocab_size` are read here, and the whole dict is passed on to
            `DocFormerEmbeddings` and `DocFormerEncoder`.
    """

    def __init__(self, config):
        super().__init__()

        ## Sub-modules are created in this fixed order
        self.resnet = ResNetFeatureExtractor()
        self.embeddings = DocFormerEmbeddings(config)
        self.lang_emb = LanguageFeatureExtractor()
        self.config = config
        ## NOTE(review): this dropout layer is constructed but never applied in forward()
        self.dropout = nn.Dropout(config['hidden_dropout_prob'])
        self.linear_layer = nn.Linear(
            in_features = config['hidden_size'],
            out_features = config['vocab_size'],
        )
        self.encoder = DocFormerEncoder(config)

    def forward(self, batch_dict):
        """Run the model on one batch.

        Args:
            batch_dict: dict providing the keys 'x_features', 'y_features',
                'input_ids' and 'resized_scaled_img'.

        Returns:
            The encoder output projected by the linear layer, i.e. a tensor whose
            last dimension has `config['vocab_size']` entries.
        """
        required_keys = ('x_features', 'y_features', 'input_ids', 'resized_scaled_img')
        x_coords, y_coords, token_ids, image = (batch_dict[key] for key in required_keys)

        ## The embeddings module returns a (visual, text) pair built from the x/y features -
        ## presumably the spatial embeddings of each modality, TODO confirm
        spatial_visual, spatial_text = self.embeddings(x_coords, y_coords)
        visual_feat = self.resnet(image)
        text_feat = self.lang_emb(token_ids)

        encoded = self.encoder(text_feat, visual_feat, spatial_text, spatial_visual)
        return self.linear_layer(encoded)

In [None]:
## Defining pytorch lightning model
from sklearn.metrics import accuracy_score

class DocFormer(pl.LightningModule):
  """LightningModule that trains DocFormerForMLM on the masked-language-modelling task.

  Args:
    config: dict of model hyper-parameters; `config['vocab_size']` is used to flatten
      the logits before the cross-entropy loss.
    lr: learning rate of the Adam optimizer (stored by `save_hyperparameters` and
      read back from `self.hparams['lr']`).
  """

  def __init__(self, config , lr = 1e-3):
    super(DocFormer, self).__init__()

    self.config = config
    self.save_hyperparameters()
    self.docformer = DocFormerForMLM(config)
    ## Per-step losses of the running epoch; averaged and cleared in the *_epoch_end hooks
    self.training_losses = []
    self.validation_losses = []

  def calculate_accuracy_score(self, prediction, labels):
    """Average, over the samples of a batch, of the token-level accuracy.

    Args:
      prediction: predicted token ids, one row per sample.
      labels: ground-truth token ids, same layout as `prediction`.

    Returns:
      Mean of the per-sample `accuracy_score` values as a Python float.

    NOTE(review): every position is scored, padding positions included - nothing is
    masked out here. Confirm whether padded positions should be excluded.
    """
    batch_size = len(prediction)
    ac_score = 0

    for (pred, gt) in zip(prediction, labels):
      ac_score+= accuracy_score(pred.cpu(), gt.cpu())
    ac_score = ac_score/batch_size
    return ac_score

  def forward(self, batch_dict):
    """Return the vocabulary logits computed by the wrapped DocFormerForMLM."""
    logits = self.docformer(batch_dict)
    return logits

  def _loss_and_accuracy(self, batch):
    """Compute the MLM cross-entropy loss and the batch accuracy (as a tensor) for one batch."""
    logits = self.forward(batch)

    ## https://discuss.huggingface.co/t/bertformaskedlm-s-loss-and-scores-how-the-loss-is-computed/607/2
    loss = nn.CrossEntropyLoss()(logits.view(-1,self.config['vocab_size']), batch['mlm_labels'].view(-1))
    _, preds = torch.max(logits, dim = -1)

    acc = torch.tensor(self.calculate_accuracy_score(preds, batch['mlm_labels']))
    return loss, acc

  def training_step(self, batch, batch_idx):
    loss, train_acc = self._loss_and_accuracy(batch)

    ## Logging - the training accuracy gets its own key so that it no longer
    ## overwrites the validation accuracy logged as 'val_acc'
    self.log('train_ce_loss', loss,prog_bar = True)
    self.log('train_acc', train_acc, prog_bar = True)
    self.training_losses.append(loss.item())

    return loss

  def validation_step(self, batch, batch_idx):
    loss, val_acc = self._loss_and_accuracy(batch)

    ## Logging
    self.log('val_ce_loss', loss, prog_bar = True)
    self.log('val_acc', val_acc, prog_bar = True)
    self.validation_losses.append(loss.item())

  def configure_optimizers(self):
    return torch.optim.Adam(self.parameters(), lr = self.hparams['lr'])

  def training_epoch_end(self, training_step_outputs):
    train_loss_mean = np.mean(self.training_losses)
    self.logger.experiment.add_scalar('training_loss', train_loss_mean, global_step=self.current_epoch)
    self.training_losses = []  # reset for next epoch

  def validation_epoch_end(self, validation_step_outputs):
    ## Average the *validation* losses (the training list was averaged here before)
    val_loss_mean = np.mean(self.validation_losses)
    self.logger.experiment.add_scalar('validation_loss', val_loss_mean, global_step=self.current_epoch)
    self.validation_losses = []  # reset for next epoch
    

In [None]:
# docformer = DocFormer(config).to(device)

In [None]:
# sample_batch_dict = next(iter(datamodule.train_dataloader()))
# for key in list(sample_batch_dict.keys()):
#   sample_batch_dict[key] = sample_batch_dict[key].to(device)

# output = docformer.forward(sample_batch_dict)
# criteria = nn.CrossEntropyLoss()
# loss = criteria(output.view(-1,config['vocab_size']), sample_batch_dict['mlm_labels'].view(-1))
# loss
# output.shape, sample_batch_dict['mlm_labels'].shape

In [None]:
# _, preds = torch.max(output, dim = -1)

In [None]:
# batch_size = len(preds)
# net_acc = 0

# for (pred, gt) in zip(preds, sample_batch_dict['mlm_labels']):
#   if gt==0:
#     break
#   net_acc+=accuracy_score(pred.cpu(), gt.cpu())
# net_acc = net_acc/batch_size

## Until Now.....

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

def main():
    """Build the data module, model, callbacks and trainer, then run the fit loop.

    Checkpoints are written to ./models and TensorBoard logs to logs/; both
    callbacks watch the 'val_ce_loss' metric (lower is better).
    """
    data = DataModule(train_dataset, val_dataset)
    model = DocFormer(config)

    ## Checkpointing and early stopping both monitor the validation loss
    callbacks = [
        ModelCheckpoint(dirpath="./models", monitor="val_ce_loss", mode="min"),
        EarlyStopping(monitor="val_ce_loss", patience=3, verbose=True, mode="min"),
    ]

    ## NOTE(review): the run name "cola" does not describe this task - confirm it is intended
    tb_logger = pl.loggers.TensorBoardLogger("logs/", name="cola", version=1)

    ## Use one GPU when available, otherwise fall back to the CPU
    gpu_count = 1 if torch.cuda.is_available() else 0

    trainer = pl.Trainer(
        default_root_dir="logs",
        gpus=gpu_count,
        max_epochs=2,
        fast_dev_run=False,
        logger=tb_logger,
        callbacks=callbacks,
    )
    trainer.fit(model, data)

In [None]:
## Start training only when this code is executed as the main module
if __name__ == "__main__":
    main()