<img src="https://images.unsplash.com/photo-1532153975070-2e9ab71f1b14?ixlib=rb-1.2.1&dl=annie-spratt-5cFwQ-WMcJU-unsplash.jpg&w=1920&q=80&fm=jpg&crop=entropy&cs=tinysrgb">


## 1. Introduction: 
* This notebook is a tutorial on DocFormer, a multi-modal architecture built mainly for document understanding.
* We take the test images of the RVL-CDIP dataset and train the model on a subset of them.
* We also log the metrics with Weights and Biases.

## A short introduction to the model:

<img src = "https://github.com/uakarsh/docformer/raw/master/images/docformer-architecture.png">

DocFormer is a multi-modal transformer based architecture for the task of Visual Document Understanding (VDU). In addition, DocFormer is pre-trained in an unsupervised fashion using carefully designed tasks which encourage multi-modal interaction. DocFormer uses text, vision and spatial features and combines them using a novel multi-modal self-attention layer. DocFormer also shares learned spatial embeddings across modalities which makes it easy for the model to correlate text to visual tokens and vice versa. DocFormer is evaluated on 4 different datasets each with strong baselines. DocFormer achieves state-of-the-art results on all of them, sometimes beating models 4x its size (in no. of parameters).

For more details on the model and its code implementation, see the repository [here](https://github.com/uakarsh/docformer). Now let us see what this model has to offer.

The report for this entire run is available [here](https://wandb.ai/iakarshu/RVL%20CDIP%20with%20DocFormer%20New%20Version/reports/Performance-of-DocFormer-with-RVL-CDIP-Test-Dataset--VmlldzoyMTI3NTM4).

<img src = "https://drive.google.com/u/1/uc?id=1IOyYXbU8bi5FDq59Z4RI1Qkoc54CzZto&export=download" >




### An interactive demo is available on 🤗 Spaces [here](https://huggingface.co/spaces/iakarshu/docformer_for_document_classification)

### Installing the Libraries ⚙️:

In [None]:
## Installing the dependencies (might take some time)

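!pip install -q pytesseract
## -y answers the confirmation prompt, which cannot be answered interactively in a notebook
!sudo apt install -y -q tesseract-ocr
!pip install -q transformers
## Assumption (version pin not in the original run): the Lightning code below uses pre-2.0 APIs
## such as `validation_epoch_end` and `Trainer(gpus=...)`
!pip install -q 'pytorch-lightning<2.0'
!pip install -q einops
!pip install -q tqdm
!pip install -q 'Pillow==7.1.2'
!pip install -q datasets
!pip install -q wandb
## Assumption (version pin not in the original run): the metrics below are created without
## the `task` argument that torchmetrics 0.11+ requires
!pip install -q 'torchmetrics<0.11'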

In [None]:
## Cloning the repository
!git clone https://github.com/uakarsh/docformer.git

In [None]:
## Logging into wandb

import wandb
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("wandb_api")
wandb.login(key=secret_value_0)

## 2. Libraries 📘:

In [None]:
## Importing the libraries

import warnings
warnings.simplefilter("ignore", UserWarning)
warnings.simplefilter("ignore", RuntimeWarning)

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import Dataset,DataLoader

import torch.nn.functional as F
import torchvision.models as models

## Adding the path of docformer to system path
import sys
sys.path.append('./docformer/src/docformer/')

## Importing the functions from the DocFormer Repo
from dataset import create_features
from modeling import DocFormerEncoder,ResNetFeatureExtractor,DocFormerEmbeddings,LanguageFeatureExtractor
from transformers import BertTokenizerFast

In [None]:
## Hyperparameters

seed = 42
target_size = (500, 384)

## Setting some hyperparameters

device = 'cuda' if torch.cuda.is_available() else 'cpu'

## One can change this configuration and try out new combination
config = {
  "coordinate_size": 96,              ## (768/8), 8 for each of the 8 coordinates of x, y
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "image_feature_pool_shape": [7, 7, 256],
  "intermediate_ff_size_factor": 4,
  "max_2d_position_embeddings": 1024,
  "max_position_embeddings": 128,
  "max_relative_positions": 8,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "shape_size": 96,
  "vocab_size": 30522,
  "layer_norm_eps": 1e-12,
}

## A small note 🗒️: 
Here, for demo purposes, I use only 250 images per class and train the model on them. For a data-hungry model such as a transformer, this is certainly not enough data, but let us see what results it gives.
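
The per-class cap can be tried out in isolation on a toy folder. The snippet below is a minimal sketch, assuming one sub-folder per class as in the dataset used here; the helper name `collect_capped_samples`, the class names, and the empty `.tif` files are hypothetical and exist only for illustration. Unlike the notebook cell, it sorts the folder listings so the label ids do not depend on the order returned by `os.listdir`.

```python
import os
import tempfile

def collect_capped_samples(base_directory, max_per_class):
    """Return (paths, labels, id2label) with at most max_per_class files per class folder."""
    id2label = sorted(
        d for d in os.listdir(base_directory)
        if os.path.isdir(os.path.join(base_directory, d))
    )
    paths, labels = [], []
    for class_id, name in enumerate(id2label):
        class_dir = os.path.join(base_directory, name)
        # Sorting makes the selection deterministic; slicing enforces the cap exactly
        for file_name in sorted(os.listdir(class_dir))[:max_per_class]:
            paths.append(os.path.join(class_dir, file_name))
            labels.append(class_id)
    return paths, labels, id2label

# Toy layout: 2 classes with 5 empty files each (stand-in for the real dataset)
with tempfile.TemporaryDirectory() as root:
    for name in ("invoice", "letter"):
        os.makedirs(os.path.join(root, name))
        for i in range(5):
            open(os.path.join(root, name, f"{i}.tif"), "w").close()
    paths, labels, id2label = collect_capped_samples(root, max_per_class=3)

print(len(paths), labels, id2label)
```

Slicing the sorted listing keeps exactly `max_per_class` items, which avoids the off-by-one that a `count > cap` check can introduce.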

In [None]:
from tqdm.auto import tqdm

## For the purpose of prediction
id2label = []
label2id = {}

curr_class = 0
## Preparing the Dataset
base_directory = '../input/the-rvlcdip-dataset-test/test'
dict_of_img_labels = {'img':[], 'label':[]}

max_sample_per_class = 250

for label in tqdm(os.listdir(base_directory)):
    img_path = os.path.join(base_directory, label)
    
    count = 0
    if label not in label2id:
        label2id[label] = curr_class
        curr_class+=1
        id2label.append(label)
        
    for img in os.listdir(img_path):
        ## Stop once exactly max_sample_per_class images have been collected for this class
        if count >= max_sample_per_class:
            break
            
        curr_img_path = os.path.join(img_path, img)
        dict_of_img_labels['img'].append(curr_img_path)
        dict_of_img_labels['label'].append(label2id[label])
        count+=1

In [None]:
import pandas as pd
df = pd.DataFrame(dict_of_img_labels)

In [None]:
from sklearn.model_selection import train_test_split as tts
## Stratified split; with no test_size given, scikit-learn holds out 25% for validation
train_df, valid_df = tts(df, random_state = seed, stratify = df['label'], shuffle = True)

In [None]:
train_df = train_df.reset_index(drop = True)
valid_df = valid_df.reset_index(drop = True)

## 3. Making the dataset 💽:

The main idea behind the dataset class is to pre-process each input into the format the model expects. So we simply pass the image path and the other configuration values, and boom 💥, we get the desired pre-processed input.

In [None]:
## Creating the dataset

class RVLCDIPData(Dataset):
    
    def __init__(self, image_list, label_list, target_size, tokenizer, max_len = 512, transform = None):
        
        self.image_list = image_list
        self.label_list = label_list
        self.target_size = target_size
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.transform = transform
        
    def __len__(self):
        return len(self.image_list)
    
    def __getitem__(self, idx):
        img_path = self.image_list[idx]
        label = self.label_list[idx]
        
        ## More on this, in the repo mentioned previously
        final_encoding = create_features(
            img_path,
            self.tokenizer,
            add_batch_dim=False,
            target_size=self.target_size,
            max_seq_length=self.max_len,
            path_to_save=None,
            save_to_disk=False,
            apply_mask_for_mlm=False,
            extras_for_debugging=False,
            use_ocr=True,
        )
        if self.transform is not None:
            ## Note that, ToTensor is already applied on the image
            final_encoding['resized_scaled_img'] = self.transform(final_encoding['resized_scaled_img'])
        
        
        keys_to_reshape = ['x_features', 'y_features', 'resized_and_aligned_bounding_boxes']
        for key in keys_to_reshape:
            final_encoding[key] = final_encoding[key][:self.max_len]
            
        final_encoding['label'] = torch.as_tensor(label).long()
        return final_encoding

In [None]:
## Defining the tokenizer
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

In [None]:
from torchvision import transforms

## Normalize each channel with mean 0.5 and std 0.5, which maps pixel values from [0, 1] to [-1, 1]
## (a common choice in tutorials and in image reconstruction work)
transform = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                              

In [None]:
train_ds = RVLCDIPData(train_df['img'].tolist(), train_df['label'].tolist(),
                      target_size, tokenizer, config['max_position_embeddings'], transform)
val_ds = RVLCDIPData(valid_df['img'].tolist(), valid_df['label'].tolist(),
                      target_size, tokenizer,config['max_position_embeddings'],  transform)

### Collate Function:

A collate function lets us control how the DataLoader assembles individual samples into a batch. More on collate functions can be found [here](https://stackoverflow.com/questions/65279115/how-to-use-collate-fn-with-dataloaders).
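
To make the idea concrete, here is a minimal sketch of a dictionary-stacking collate function on toy data. The function name `stack_dict_collate`, the sample keys, and the shapes are made up for illustration. It assumes every sample has the same keys and that tensors stored under the same key share one shape, since `torch.stack` requires equal shapes.

```python
import torch
from torch.utils.data import DataLoader

def stack_dict_collate(samples):
    """Turn a list of {key: tensor} samples into one {key: batched tensor} dict."""
    keys = samples[0].keys()
    return {key: torch.stack([sample[key] for sample in samples], dim=0) for key in keys}

# Toy dataset: 6 samples, each with a length-4 token vector and a scalar label
toy_samples = [
    {"input_ids": torch.full((4,), i), "label": torch.tensor(i % 2)}
    for i in range(6)
]

# A plain list works as a map-style dataset because it supports len() and indexing
loader = DataLoader(toy_samples, batch_size=3, collate_fn=stack_dict_collate)
batch = next(iter(loader))
print(batch["input_ids"].shape, batch["label"].shape)
```

Each value in the batch gains a leading batch dimension, so the model can read `batch["input_ids"]` or `batch["label"]` directly by key.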

In [None]:
def collate_fn(data_bunch):

  '''
  A function for the dataloader to return a batch dict of given keys

  data_bunch: List of dictionary
  '''

  dict_data_bunch = {}

  for i in data_bunch:
    for (key, value) in i.items():
      if key not in dict_data_bunch:
        dict_data_bunch[key] = []
      dict_data_bunch[key].append(value)

  for key in list(dict_data_bunch.keys()):
    dict_data_bunch[key] = torch.stack(dict_data_bunch[key], dim = 0)

  return dict_data_bunch

## 4. Defining the DataModule 📖

* A DataModule is a shareable, reusable class that encapsulates all the steps needed to process data.

* A DataModule is simply a collection of a train_dataloader(s), val_dataloader(s), test_dataloader(s) and predict_dataloader(s) along with the matching transforms and data processing/downloads steps required.




In [None]:
import pytorch_lightning as pl

class DataModule(pl.LightningDataModule):

  def __init__(self, train_dataset, val_dataset,  batch_size = 4):

    super(DataModule, self).__init__()
    self.train_dataset = train_dataset
    self.val_dataset = val_dataset
    self.batch_size = batch_size

  def train_dataloader(self):
    return DataLoader(self.train_dataset, batch_size = self.batch_size, 
                      collate_fn = collate_fn, shuffle = True)
  
  def val_dataloader(self):
    return DataLoader(self.val_dataset, batch_size = self.batch_size,
                                  collate_fn = collate_fn, shuffle = False)

In [None]:
datamodule = DataModule(train_ds, val_ds)

## 5. Modeling Part 🏎️

1. First, we define the PyTorch model with our configuration; the class labels range from 0 to 15.
2. Second, we wrap it in a PyTorch Lightning module, and boom 💥, our work of defining the model is done.
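
As a rough illustration of the classification head, the sketch below uses a random tensor in place of the encoder output. It assumes the encoder returns a tensor of shape `(batch, sequence, hidden_size)` and that the first position serves as the document representation; the sizes mirror the config above, and `num_classes = 16` matches the label range 0 to 15. It also checks that projecting every token and then keeping position 0 gives the same logits as keeping position 0 first, which needs far fewer multiplications.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch_size, seq_len, hidden_size, num_classes = 2, 128, 768, 16

# Stand-in for the encoder output (random values, for illustration only)
encoder_output = torch.randn(batch_size, seq_len, hidden_size)
head = nn.Linear(hidden_size, num_classes)

# Option A: project all tokens, then keep the first position
logits_a = head(encoder_output)[:, 0, :]

# Option B: keep the first position, then project only that vector
logits_b = head(encoder_output[:, 0, :])

# Predicted class ids, one per document
preds = logits_b.argmax(dim=1)
print(logits_b.shape, torch.allclose(logits_a, logits_b, atol=1e-4))
```

Both options yield one row of 16 logits per document; the linear layer acts on each position independently, so the order of slicing and projecting does not change the result beyond floating-point noise.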

In [None]:
class DocFormerForClassification(nn.Module):
  
    def __init__(self, config):
      super(DocFormerForClassification, self).__init__()

      self.resnet = ResNetFeatureExtractor(hidden_dim = config['max_position_embeddings'])
      self.embeddings = DocFormerEmbeddings(config)
      self.lang_emb = LanguageFeatureExtractor()
      self.config = config
      self.dropout = nn.Dropout(config['hidden_dropout_prob'])
      self.linear_layer = nn.Linear(in_features = config['hidden_size'], out_features = len(id2label))  ## Number of Classes
      self.encoder = DocFormerEncoder(config)

    def forward(self, batch_dict):

      x_feat = batch_dict['x_features']
      y_feat = batch_dict['y_features']

      token = batch_dict['input_ids']
      img = batch_dict['resized_scaled_img']

      v_bar_s, t_bar_s = self.embeddings(x_feat,y_feat)
      v_bar = self.resnet(img)
      t_bar = self.lang_emb(token)
      out = self.encoder(t_bar,v_bar,t_bar_s,v_bar_s)
      out = self.linear_layer(out)
      out = out[:, 0, :]
      return out

In [None]:
## Defining pytorch lightning model
from sklearn.metrics import accuracy_score, confusion_matrix
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torchmetrics

class DocFormer(pl.LightningModule):

  def __init__(self, config , lr = 5e-5):
    super(DocFormer, self).__init__()
    
    self.save_hyperparameters()
    self.config = config
    self.docformer = DocFormerForClassification(config)
    
    self.num_classes = len(id2label)
    self.train_accuracy_metric = torchmetrics.Accuracy()
    self.val_accuracy_metric = torchmetrics.Accuracy()
    self.f1_metric = torchmetrics.F1Score(num_classes=self.num_classes)
    self.precision_macro_metric = torchmetrics.Precision(
            average="macro", num_classes=self.num_classes
        )
    self.recall_macro_metric = torchmetrics.Recall(
            average="macro", num_classes=self.num_classes
        )
    self.precision_micro_metric = torchmetrics.Precision(average="micro")
    self.recall_micro_metric = torchmetrics.Recall(average="micro")

  def forward(self, batch_dict):
    logits = self.docformer(batch_dict)
    return logits

  def training_step(self, batch, batch_idx):
    logits = self.forward(batch)

    loss = nn.CrossEntropyLoss()(logits, batch['label'])
    preds = torch.argmax(logits, 1)

    ## Calculating the accuracy score
    train_acc = self.train_accuracy_metric(preds, batch["label"])

    ## Logging
    self.log('train/loss', loss,prog_bar = True, on_epoch=True, logger=True, on_step=True)
    self.log('train/acc', train_acc, prog_bar = True, on_epoch=True, logger=True, on_step=True)

    return loss
  
  def validation_step(self, batch, batch_idx):
    logits = self.forward(batch)
    loss = nn.CrossEntropyLoss()(logits, batch['label'])
    preds = torch.argmax(logits, 1)
    
    labels = batch['label']
    # Metrics
    valid_acc = self.val_accuracy_metric(preds, labels)
    precision_macro = self.precision_macro_metric(preds, labels)
    recall_macro = self.recall_macro_metric(preds, labels)
    precision_micro = self.precision_micro_metric(preds, labels)
    recall_micro = self.recall_micro_metric(preds, labels)
    f1 = self.f1_metric(preds, labels)

    # Logging metrics
    self.log("valid/loss", loss, prog_bar=True, on_step=True, logger=True)
    self.log("valid/acc", valid_acc, prog_bar=True, on_epoch=True, logger=True, on_step=True)
    self.log("valid/precision_macro", precision_macro, prog_bar=True, on_epoch=True, logger=True, on_step=True)
    self.log("valid/recall_macro", recall_macro, prog_bar=True, on_epoch=True, logger=True, on_step=True)
    self.log("valid/precision_micro", precision_micro, prog_bar=True, on_epoch=True, logger=True, on_step=True)
    self.log("valid/recall_micro", recall_micro, prog_bar=True, on_epoch=True, logger=True, on_step=True)
    self.log("valid/f1", f1, prog_bar=True, on_epoch=True)
    
    return {"label": batch['label'], "logits": logits}

  def validation_epoch_end(self, outputs):
        labels = torch.cat([x["label"] for x in outputs])
        logits = torch.cat([x["logits"] for x in outputs])
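        preds = torch.argmax(logits, 1)
        ## wandb.plot.roc_curve expects class probabilities, so convert the logits with softmax
        probs = torch.softmax(logits, dim=1)

        wandb.log({"cm": wandb.sklearn.plot_confusion_matrix(labels.cpu().numpy(), preds.cpu().numpy())})
        self.logger.experiment.log(
            {"roc": wandb.plot.roc_curve(labels.cpu().numpy(), probs.cpu().numpy())}
        )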
        
  def configure_optimizers(self):
    return torch.optim.AdamW(self.parameters(), lr = self.hparams['lr'])

## 6. Summing it up and running the entire procedure 🏃

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import WandbLogger

def main():
    ## Seed before building the model so that weight initialization is reproducible too
    ## https://www.tutorialexample.com/implement-reproducibility-in-pytorch-lightning-pytorch-lightning-tutorial/
    pl.seed_everything(seed, workers=True)

    datamodule = DataModule(train_ds, val_ds)
    docformer = DocFormer(config)

    checkpoint_callback = ModelCheckpoint(
        dirpath="./models", monitor="valid/loss", mode="min"
    )
    early_stopping_callback = EarlyStopping(
        monitor="valid/loss", patience=3, verbose=True, mode="min"
    )
    
    wandb.init(config=config, project="RVL CDIP with DocFormer New Version")
    wandb_logger = WandbLogger(project="RVL CDIP with DocFormer New Version", entity="iakarshu")
    trainer = pl.Trainer(
        default_root_dir="logs",
        gpus=(1 if torch.cuda.is_available() else 0),
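        ## With a single epoch, EarlyStopping (patience=3) can never trigger; raise max_epochs for longer runs
        max_epochs=1,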
        fast_dev_run=False,
        logger=wandb_logger,
        callbacks=[checkpoint_callback, early_stopping_callback],
        deterministic=True
    )
    trainer.fit(docformer, datamodule)

In [None]:
if __name__ == "__main__":
    main()

## References:

1. [MLOps Repo](https://github.com/graviraja/MLOps-Basics) (for integrating the model and data with PyTorch Lightning)
2. [PyTorch Lightning Docs](https://pytorch-lightning.readthedocs.io/en/stable/index.html) (for all the doubts and bugs)
3. [My Repo](https://github.com/uakarsh/docformer) (for downloading the model and the pre-processing steps)
4. Unsplash for images
5. Google for everything else