<a href="https://colab.research.google.com/github/uakarsh/TiLT-Implementation/blob/main/how_did_i_prepare_the_stuffs/tilt_part_3_2_adding_a_final_touch_and_then_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!git clone https://github.com/uakarsh/TiLT-Implementation.git

Cloning into 'TiLT-Implementation'...
remote: Enumerating objects: 116, done.[K
remote: Counting objects: 100% (116/116), done.[K
remote: Compressing objects: 100% (85/85), done.[K
remote: Total 116 (delta 52), reused 77 (delta 23), pack-reused 0[K
Receiving objects: 100% (116/116), 2.94 MiB | 2.57 MiB/s, done.
Resolving deltas: 100% (52/52), done.


In [None]:
!pip install -r /content/TiLT-Implementation/requirements.txt

In [None]:
import sys
sys.path.append("/content/TiLT-Implementation/src/")

In [None]:
from transformers import AutoTokenizer, AutoConfig
from datasets import load_dataset
import torch
import torch.nn as nn

from dataset import FUNSDDs
from torchvision import transforms
from tqdm.auto import tqdm

## Custom imports
from visual_backbone import Unet_encoder, RoIPool
from t5 import T5ForConditionalGeneration, T5Stack
from transformers import AutoModel

## 1.1. Preparing the dataset

In [None]:
## Run on the GPU when one is visible, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

## `nielsr/funsd-layoutlmv3` as served by `datasets.load_dataset`
hf_ds = load_dataset("nielsr/funsd-layoutlmv3")
model_name = "t5-base"

## Visual Embedding extractor's parameters
in_channels = 3
num_pool_layers = 3
channels = 16
sampling_ratio = 2
## NOTE(review): presumably (feature-map side) / (input image side) -- confirm against Unet_encoder
spatial_scale = 48 / 384
output_size = (3,3)
load_weights = True

## Tokenizer's parameter
model_max_length = 512

t5_config = AutoConfig.from_pretrained(model_name)
## Adding new parameters
## `tokenizer_name` is read by `TiltModel.__init__`, so it has to be present on the config too
t5_config.update(dict(
    in_channels = in_channels,
    num_pool_layers = num_pool_layers,
    channels = channels,
    model_max_length = model_max_length,
    output_size = output_size,
    spatial_scale = spatial_scale,
    sampling_ratio = sampling_ratio,
    use_cache = False,
    load_weights = load_weights,
    lr = 2e-4,
    tokenizer_name = model_name,
))

## Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast = True, model_max_length = model_max_length)

Downloading builder script:   0%|          | 0.00/5.13k [00:00<?, ?B/s]

Downloading and preparing dataset funsd-layoutlmv3/funsd to /root/.cache/huggingface/datasets/nielsr___funsd-layoutlmv3/funsd/1.0.0/0e3f4efdfd59aa1c3b4952c517894f7b1fc4d75c12ef01bcc8626a69e41c1bb9...


Downloading data:   0%|          | 0.00/16.8M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Dataset funsd-layoutlmv3 downloaded and prepared to /root/.cache/huggingface/datasets/nielsr___funsd-layoutlmv3/funsd/1.0.0/0e3f4efdfd59aa1c3b4952c517894f7b1fc4d75c12ef01bcc8626a69e41c1bb9. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

Downloading (…)ve/main/spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

In [None]:
def get_id2label_and_label2id():
    """Build the tag vocabulary in both directions.

    Returns:
        tuple: `(id2label, label2id)` where `id2label` maps the integers
        0..6 to the tag strings and `label2id` is the inverse mapping.
    """
    ## Position in this list is the integer id of the tag
    tag_names = ['O', 'B-HEADER', 'I-HEADER', 'B-QUESTION', 'I-QUESTION', 'B-ANSWER', 'I-ANSWER']
    id2label = dict(enumerate(tag_names))
    label2id = {tag: idx for idx, tag in id2label.items()}
    return id2label, label2id

def convert_id_to_label(list_of_label, mapping = None):
  """Translate a sequence of integer tag ids into their string labels.

  Args:
    list_of_label: iterable of integer tag ids.
    mapping: optional id -> label dict. When omitted, the notebook-level
      `id2label` dict is used, so that global must exist by call time.

  Returns:
    list: one label per input id, in the same order.

  Raises:
    KeyError: if an id is missing from the mapping.
  """
  if mapping is None:
    ## Resolved at call time, not definition time, so the cell order of the notebook still works
    mapping = id2label
  return [mapping[x] for x in list_of_label]

In [None]:
## Tag vocabulary used by `convert_id_to_label`
id2label, label2id = get_id2label_and_label2id()
## Image preprocessing: `ToTensor`, then `2 * x - 1`.
## NOTE(review): assuming ToTensor yields values in [0, 1], this rescales them to [-1, 1] -- confirm
transform = transforms.Compose([transforms.ToTensor(), 
                                transforms.Lambda(lambda x : 2 * x - 1)])

In [None]:
## Replace every example's integer tag ids with the matching label strings
train_new_tags = [convert_id_to_label(tag_ids) for tag_ids in hf_ds['train']['ner_tags']]
test_new_tags = [convert_id_to_label(tag_ids) for tag_ids in hf_ds['test']['ner_tags']]

In [None]:
## Swap the integer `ner_tags` column of each split for the string version built above.
## This rebinds `hf_ds[...]`, so re-running the tag conversion afterwards would see strings, not ids.
hf_ds['train'] = hf_ds['train'].remove_columns("ner_tags").add_column("ner_tags", train_new_tags)
hf_ds['test'] = hf_ds['test'].remove_columns("ner_tags").add_column("ner_tags", test_new_tags)

In [None]:
## Wrap both splits in the project's `FUNSDDs` dataset; the `test` split is used as validation data
train_ds = FUNSDDs(hf_ds['train'],tokenizer = tokenizer, transform = transform)
val_ds = FUNSDDs(hf_ds['test'],tokenizer = tokenizer, transform = transform)

## 1.2 Writing the `collate_fn` for custom batching in the dataloader

In [None]:
class CollateFn(object):
  """Callable that merges a list of dataset items into a single batch dict.

  Args:
    tokenizer: object exposing `batch_encode_plus`; used to turn the
      per-item `labels` word lists into token ids.
  """

  def __init__(self, tokenizer):
    self.tokenizer = tokenizer

  def __call__(self, list_of_ds):
    """Collate `list_of_ds` into one dict.

    Args:
      list_of_ds: list of per-example dicts holding `input_ids`,
        `attention_mask`, `bboxes`, `pixel_values` and `labels`.

    Returns:
      dict: the four tensor entries stacked along a new first dimension,
      plus `labels` holding the `input_ids` returned by the tokenizer.
    """
    ## torch.stack needs every item's entry to be a tensor of identical shape
    stacked_keys = ("input_ids", "attention_mask", "bboxes", "pixel_values")
    batch = {name: torch.stack([item[name] for item in list_of_ds]) for name in stacked_keys}

    ## Labels arrive as lists of words and are tokenized here, padded to the tokenizer's max length
    word_labels = [item['labels'] for item in list_of_ds]
    encoded_labels = self.tokenizer.batch_encode_plus(
        word_labels,
        return_tensors = 'pt',
        is_split_into_words = True,
        padding = 'max_length',
        truncation = True,
    )
    batch['labels'] = encoded_labels['input_ids']
    return batch

In [None]:
## Collator bound to the notebook-level tokenizer
collate_fn = CollateFn(tokenizer)

In [None]:
## Collate the first two training examples into a sample batch and move every entry to `device`
sample_batch_encoding = collate_fn([train_ds[0], train_ds[1]])
for key in sample_batch_encoding:
  sample_batch_encoding[key] = sample_batch_encoding[key].to(device)

## 2.1 Preparing the visual model

In [None]:
class VisualEmbedding(nn.Module):
  """Visual branch: encode the image, pool one region per bounding box, project to `d_model`.

  Args:
    config: object providing `in_channels`, `channels`, `num_pool_layers`,
      `output_size`, `spatial_scale` and `d_model`.
  """

  def __init__(self, config):
    super().__init__()
    self.config = config

    encoder_kwargs = dict(in_channels = config.in_channels,
                          channels = config.channels,
                          num_pool_layers = config.num_pool_layers)
    self.unet_encoder = Unet_encoder(**encoder_kwargs)

    self.roi_pool = RoIPool(output_size = config.output_size,
                            spatial_scale = config.spatial_scale)

    ## NOTE(review): 128 * 3 * 3 is hard-coded; presumably (encoder output channels) * output_size
    ## for the current config -- confirm before changing `channels`, `num_pool_layers` or `output_size`
    self.proj = nn.Linear(in_features = 128 * 3 * 3, out_features = config.d_model)

  def forward(self, pixel_values, bboxes):
    """Return the projected region features for `bboxes` taken from `pixel_values`.

    The encoder output is pooled per box, everything from dimension 2 onwards
    is flattened, and the result goes through the linear projection.
    """
    feature_map = self.unet_encoder(pixel_values)
    pooled_regions = self.roi_pool(feature_map, bboxes)
    flat_regions = pooled_regions.flatten(2)
    return self.proj(flat_regions)

## 2.2 Preparing the semantic model

In [None]:
class TiLTTransformer(nn.Module):
  """T5 model whose input embeddings are the sum of token and visual embeddings.

  Args:
    config: configuration handed unchanged to both `VisualEmbedding` and
      `T5ForConditionalGeneration`.
  """

  def __init__(self, config):
    super().__init__()
    self.config = config
    self.visual_embedding_extractor = VisualEmbedding(config)
    self.t5_model = T5ForConditionalGeneration(config)

  def generate(self, batch, **generate_kwargs):
    """Generate output sequences for `batch`.

    Args:
      batch: mapping with `pixel_values`, `bboxes`, `input_ids` and,
        optionally, `attention_mask`.
      **generate_kwargs: extra keyword arguments forwarded to
        `self.t5_model.generate` (an explicit `attention_mask` here wins
        over the one in `batch`).

    Returns:
      Whatever `self.t5_model.generate` returns.
    """
    total_embedding = self.common_step(batch)

    ## Forward the padding mask when the batch has one, mirroring `forward`
    attention_mask = batch.get('attention_mask')
    if attention_mask is not None:
      generate_kwargs.setdefault('attention_mask', attention_mask)

    ## The keyword is `inputs_embeds` (same as in `forward`), not `input_embeds`
    return self.t5_model.generate(inputs_embeds = total_embedding, **generate_kwargs)

  def common_step(self, batch):
    """Build the combined input embedding for `batch`."""
    ## Visual embedding
    visual_embedding = self.visual_embedding_extractor(pixel_values = batch['pixel_values'], bboxes = batch['bboxes'])

    ## Semantic embedding from t5_model's embedding layer
    semantic_embedding = self.t5_model.shared(batch['input_ids'])

    ## Net embedding is addition of both the embeddings
    total_embedding = visual_embedding + semantic_embedding

    return total_embedding

  def forward(self, batch):
    """Run T5 on the combined embedding with `batch['labels']` as targets.

    Returns:
      The output of `self.t5_model(...)`.
    """
    total_embedding = self.common_step(batch)

    ## This is then fed to t5_model
    final_output = self.t5_model(attention_mask = batch['attention_mask'], inputs_embeds = total_embedding,
                            labels = batch['labels'])

    return final_output

In [None]:
## Instantiate the model on `device` and run a single forward pass on the sample batch
tilt_model = TiLTTransformer(t5_config).to(device)
output = tilt_model(sample_batch_encoding)

Downloading pytorch_model.bin:   0%|          | 0.00/892M [00:00<?, ?B/s]

Weights loaded successfully!


## 3.1 Preparing the metrics to evaluate the predictions

In [None]:
import evaluate

def create_batch_complete_labels(decoded_labels, class_labels = ['O', 'B-HEADER', 'I-HEADER', 'B-QUESTION', 'I-QUESTION', 'B-ANSWER', 'I-ANSWER']):
  """Snap every whitespace-separated token of each decoded string onto a known class label.

  A token is replaced by the first entry of `class_labels` that starts with
  it, so a truncated tag such as `B-HEAD` becomes `B-HEADER`. Tokens that
  are not a prefix of any class label become `"O"`.

  Args:
    decoded_labels: iterable of decoded strings, one per example.
    class_labels: candidate labels, checked in order.

  Returns:
    list[list[str]]: one list of completed labels per input string
    (an empty list when `decoded_labels` is empty).
  """

  ## class_labels should be changed for CORD dataset, currently it is for FUNSD

  completed_labels = []

  for single_prediction in decoded_labels:
    completed_single_label = [
        ## First matching class label wins; fall back to "O" when nothing matches
        next((class_label for class_label in class_labels if class_label.startswith(single_label)), "O")
        for single_label in single_prediction.split()
    ]
    completed_labels.append(completed_single_label)

  ## Return only after the whole batch has been processed, not after the first prediction
  return completed_labels

In [None]:
## Tokenized labels of the sample batch, reused below to exercise the metric pipeline
labels = sample_batch_encoding['labels']

In [None]:
## Back from token ids to text, one string per example, with special tokens dropped
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens = True)

In [None]:
## Split each decoded string and map every token onto a known class label
completed_labels = create_batch_complete_labels(decoded_labels)

In [None]:
## Load the seqeval metric and score the completed labels against themselves
## (predictions and references are the same object here, so this only checks that the pipeline runs)
eval_metric = evaluate.load("seqeval")
metric = eval_metric.compute(predictions = completed_labels, references = completed_labels)

## Part 4: Writing the `pytorch_lightning` code for training on FUNSD

In [None]:
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint
import pytorch_lightning as pl

class TiltModel(pl.LightningModule):
    """Lightning wrapper around `TiLTTransformer` with seqeval logging.

    Args:
        config: model configuration. Besides what `TiLTTransformer` needs,
            it must provide `tokenizer_name` (passed to
            `AutoTokenizer.from_pretrained`) and `lr` (AdamW learning rate).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.tilt_model = TiLTTransformer(config)
        self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name)
        self.eval_metric = evaluate.load("seqeval")

    def forward(self, batch):
        """Delegate to the wrapped `TiLTTransformer`."""
        return self.tilt_model(batch)

    def _shared_step(self, batch, stage):
        """Run one step on `batch` and log loss and seqeval scores under the `stage` prefix.

        Args:
            batch: dict produced by the collate function; must contain `labels`.
            stage: prefix for the logged metric names (`"train"` or `"val"`).

        Returns:
            The loss from the model output.
        """
        output = self(batch)
        loss = output.loss

        ## Greedy token choice per position, then back to text for the metric
        predictions = output.logits.argmax(dim=-1)
        labels = batch['labels']
        decoded_labels = self.tokenizer.batch_decode(
            labels, skip_special_tokens=True)
        decoded_predictions = self.tokenizer.batch_decode(
            predictions, skip_special_tokens=True)

        ## NOTE(review): seqeval receives the decoded strings as-is here, whereas the metric
        ## check earlier in the notebook first runs them through `create_batch_complete_labels`
        ## -- confirm which input format is intended.
        results = self.eval_metric.compute(
            predictions=decoded_predictions, references=decoded_labels)

        ## `loss` is the local holding the model loss (the name `outputs` never existed here).
        ## The `_overall_fl` key (letter "l") is kept unchanged so existing monitors keep working.
        self.log(f"{stage}_loss", loss.item(),
                 prog_bar=True, on_epoch=True, logger=True)
        self.log(f"{stage}_overall_fl", results["overall_f1"],
                 prog_bar=True, on_epoch=True, logger=True)
        self.log(f"{stage}_overall_recall", results["overall_recall"],
                 prog_bar=True, on_epoch=True, logger=True)
        self.log(f"{stage}_overall_precision", results["overall_precision"],
                 prog_bar=True, on_epoch=True, logger=True)

        return loss

    def training_step(self, batch, batch_idx):
        """Training step: returns the loss and logs the `train_*` metrics."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Validation step: returns the loss and logs the `val_*` metrics."""
        return self._shared_step(batch, "val")

    def configure_optimizers(self):
        """AdamW over all parameters with the learning rate taken from `config.lr`."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.config.lr)
        return optimizer