<a href="https://colab.research.google.com/github/uakarsh/AkarshU/blob/master/examples/DocFormer_for_MLM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Enable a CUDA (GPU) runtime so that everything can be processed easily

### 1. About the Notebook:

This notebook demonstrates using DocFormer for Masked Language Modeling (MLM), without pre-trained weights.

In [None]:
## Installing the dependencies (might take some time)

%%capture
!pip install pytesseract
!sudo apt install tesseract-ocr
!pip install transformers
!pip install pytorch-lightning
!pip install einops
!pip install accelerate
!pip install tqdm
!pip install torchmetrics

In [None]:
%%capture
# Pin Pillow to 7.1.2 (the reason for the pin is not stated in this notebook — TODO confirm it is still required)
!pip install 'Pillow==7.1.2'

In [None]:
## Cloning the repository

%%capture
!git clone https://github.com/shabie/docformer.git

In [None]:
## Importing the libraries

import os
import pickle
import pytesseract
import numpy as np
import pandas as pd
from PIL import Image,ImageDraw
import torch
from torchvision.transforms import ToTensor
import torch.nn as nn
from torch.utils.data import Dataset,DataLoader

import math
import torch.nn.functional as F
import torchvision.models as models
from einops import rearrange
from torch import Tensor


## Adding the path of docformer to system path
import sys
sys.path.append('/content/docformer/src/docformer/')



## Importing the functions from the DocFormer Repo
from dataset import create_features
from modeling import DocFormerEmbeddings, DocFormerEncoder, LanguageFeatureExtractor, ResNetFeatureExtractor
from transformers import BertTokenizerFast

In [None]:
## Hyperparameters and the device used throughout the notebook

# Default to the CPU and switch to the GPU only when PyTorch can see one.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Model configuration handed to the DocFormer modules.
config = dict(
    coordinate_size=96,               ## (768/8), 8 for each of the 8 coordinates of x, y
    hidden_dropout_prob=0.1,
    hidden_size=768,
    image_feature_pool_shape=[7, 7, 256],
    intermediate_ff_size_factor=4,
    max_2d_position_embeddings=1024,
    max_position_embeddings=512,
    max_relative_positions=8,
    num_attention_heads=12,
    num_hidden_layers=12,
    pad_token_id=0,
    shape_size=96,
    vocab_size=30522,
    layer_norm_eps=1e-12,
)

## 2. Making the dataset

In [None]:
class DocumentDataset(Dataset):
    '''
    Dataset that builds DocFormer input features lazily, one entry at a time.

    Args:
        entries: indexable collection whose items are handed to `create_features`
            (this notebook passes image file paths).
        tokenizer: tokenizer forwarded to `create_features`.
        labels: optional per-entry labels, indexed in parallel with `entries`.
        use_mlm: forwarded to `create_features` as `apply_mask_for_mlm`; when it is
            true and no labels are given, `mlm_labels` is appended to every item.
    '''

    def __init__(self,entries,tokenizer,labels = None, use_mlm = False):

        self.use_mlm = use_mlm
        self.entries = entries
        self.labels = labels
        self.tokenizer = tokenizer
        # NOTE: this is the notebook-level `config` dict, not a constructor argument,
        # so the config cell has to be executed before a dataset is created.
        self.config = config

    def __len__(self) -> int:
        '''Number of entries in the dataset.'''
        return len(self.entries)

    def __getitem__(self,index):

        '''
        Returns the four required inputs as a tuple,
        * resized_scaled_img
        * input_ids
        * x_features
        * y_features

        followed by one optional fifth element:
        * labels[index], if labels were given
        * otherwise mlm_labels, if use_mlm is set
        '''
        encoding = create_features(self.entries[index],self.tokenizer, apply_mask_for_mlm=self.use_mlm)

        item = (
            encoding['resized_scaled_img'],
            encoding['input_ids'],
            encoding['x_features'],
            encoding['y_features'],
        )

        # Identity check on purpose: `labels == None` is evaluated element-wise for
        # NumPy arrays (and similar containers) and cannot be used as an `if` condition.
        if self.labels is not None:
            return item + (self.labels[index],)

        if self.use_mlm:
            return item + (encoding['mlm_labels'],)

        return item

In [None]:
# Load the fast BERT tokenizer for the "bert-base-uncased" checkpoint
# (its files are downloaded on first use, as the progress bars in the output show).
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/455k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

##### Downloading a sample of the RVL-CDIP dataset: it contains a few images (from the invoice class of RVL-CDIP) for the purpose of MLM

In [None]:
%%capture
# Clone the sample dataset repository into the current working directory (output hidden by %%capture)
!git clone https://github.com/uakarsh/sample_rvl_cdip_dataset.git

In [None]:
# Folder inside the cloned dataset repository that holds the sample images
base_path = '/content/sample_rvl_cdip_dataset/RVL-CDIP Invoice Class Dataset'

# One row per directory entry; `image_id` stores the full path of that entry
fp = pd.DataFrame({
    'image_id': [
        os.path.join(base_path, file_name)
        for file_name in os.listdir(base_path)
    ]
})

In [None]:
# Dataset over every collected path; with use_mlm=True each item also carries `mlm_labels`
train_ds = DocumentDataset(
    entries=fp['image_id'].tolist(),
    tokenizer=tokenizer,
    use_mlm=True,
)

def collate_fn(batch):
    """Transpose a batch: a sequence of per-sample tuples becomes a tuple of per-field tuples."""
    per_field = zip(*batch)
    return tuple(per_field)

# Shuffled mini-batches of two samples, loaded in the main process (num_workers=0);
# collate_fn keeps each field as a tuple instead of letting the loader stack tensors.
train_data_loader = DataLoader(
    dataset=train_ds,
    batch_size=2,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_fn,
)

## 3. Making the model and doing the propagation

In [None]:
class DocFormerForMLM(nn.Module):
    """DocFormer encoder followed by a linear head with `vocab_size` outputs, for masked language modeling.

    Args:
        config: dict of hyperparameters. This class reads `hidden_dropout_prob`,
            `hidden_size` and `vocab_size`, and passes the whole dict on to
            `DocFormerEmbeddings` and `DocFormerEncoder`.
    """

    def __init__(self, config):
        super().__init__()

        self.resnet = ResNetFeatureExtractor()
        self.embeddings = DocFormerEmbeddings(config)
        self.lang_emb = LanguageFeatureExtractor()
        self.config = config
        self.dropout = nn.Dropout(config['hidden_dropout_prob'])
        self.linear_layer = nn.Linear(in_features = config['hidden_size'], out_features = config['vocab_size'])
        self.encoder = DocFormerEncoder(config)

    def forward(self, x_feat, y_feat, img, token):
        """Run one forward pass and return the output of the vocabulary head.

        Args:
            x_feat: x features, passed to the spatial embedding module.
            y_feat: y features, passed to the spatial embedding module.
            img: image batch, passed to the ResNet feature extractor.
            token: token ids, passed to the language embedding module.

        Returns:
            Output of the final linear layer; its last dimension has size
            `vocab_size` (presumably (batch, sequence, vocab_size) — TODO confirm
            against DocFormerEncoder).
        """
        # Spatial embeddings for the visual and the textual stream
        v_bar_s, t_bar_s = self.embeddings(x_feat,y_feat)
        v_bar = self.resnet(img)
        t_bar = self.lang_emb(token)
        out = self.encoder(t_bar,v_bar,t_bar_s,v_bar_s)
        # The dropout layer was constructed in __init__ but never applied; use it
        # before the head (it is a no-op in eval mode).
        out = self.dropout(out)
        out = self.linear_layer(out)

        return out

In [None]:
# Build the model from `config` (no pre-trained DocFormer weights are loaded in this notebook) and move it to the selected device
model = DocFormerForMLM(config).to(device)

In [None]:
## Taking one batch from the loader for the forward propagation

batch_fields = next(iter(train_data_loader))

# Every field arrives as a tuple with one entry per sample; stacking joins them into one tensor per field
final_data = [torch.stack(field) for field in batch_fields]

del batch_fields
img, token, x_feat, y_feat, labels = final_data

In [None]:
## Moving every tensor of the batch onto the selected device

img, token, x_feat, y_feat, labels = (
    tensor.to(device) for tensor in (img, token, x_feat, y_feat, labels)
)

In [None]:
## Forward Propagation
## Argument order follows DocFormerForMLM.forward: x features, y features, image, token ids

out = model(x_feat, y_feat, img, token)

In [None]:
## Initializing the loss and optimizer

criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr= 5e-5)


## Calculating the loss and back propagating
## Clear any gradients left on the parameters by an earlier run of these cells;
## otherwise backward() would add the new gradients on top of the stale ones.
optimizer.zero_grad()

## transpose(1,2) moves the last axis of `out` (the vocabulary scores from the final
## linear layer) to position 1, which is where CrossEntropyLoss expects the classes.
loss = criterion(out.transpose(1,2), labels.long())
loss.backward()
optimizer.step()