In [1]:
!pip install -qqq wandb

In [2]:
## Logging into wandb
# The API key is read from the Kaggle secret named "wandb_api", so it is not
# hard-coded in the notebook source.

import wandb
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("wandb_api")
# NOTE(review): the saved output of this cell shows wandb appending the key to
# the netrc file of the session.
wandb.login(key=secret_value_0)

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [3]:
!git clone https://github.com/uakarsh/TiLT-Implementation.git

Cloning into 'TiLT-Implementation'...
remote: Enumerating objects: 54, done.[K
remote: Counting objects: 100% (54/54), done.[K
remote: Compressing objects: 100% (43/43), done.[K
remote: Total 54 (delta 19), reused 41 (delta 9), pack-reused 0[K
Receiving objects: 100% (54/54), 270.58 KiB | 4.10 MiB/s, done.
Resolving deltas: 100% (19/19), done.


In [4]:
!pip install -r ./TiLT-Implementation/requirements.txt

Collecting seqeval
  Downloading seqeval-1.2.2.tar.gz (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l- \ done
[?25hCollecting evaluate
  Downloading evaluate-0.4.0-py3-none-any.whl (81 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m81.4/81.4 kB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
Building wheels for collected packages: seqeval
  Building wheel for seqeval (setup.py) ... [?25l- \ | done
[?25h  Created wheel for seqeval: filename=seqeval-1.2.2-py3-none-any.whl size=16179 sha256=64a306bf144018138366778f6980776510dc4c5903cd6b96142dd6a60fea7159
  Stored in directory: /root/.cache/pip/wheels/b2/a1/b7/0d3b008d0c77cd57332d724b92cf7650b4185b493dc785f00a
Successfully built seqeval
Installing collected packages: seqeval, evaluate
Successfully installed evaluate-0.4.0 seqeval-1.2.2
[0m

In [5]:
import sys
sys.path.append("./TiLT-Implementation/src/")

In [6]:
import os
from transformers import AutoTokenizer, AutoConfig
from datasets import load_from_disk
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from pytorch_lightning.loggers import CSVLogger, WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
import pytorch_lightning as pl

from dataset import ExtFUNSDDs
from torchvision import transforms
from tqdm.auto import tqdm

## Custom imports
from visual_backbone import Unet_encoder, RoIPool
from t5 import T5ForConditionalGenerationAbstractive, T5Stack
from transformers import AutoModel

## 1.1. Preparing the dataset

In [7]:
# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the dataset saved on disk and rename its 'words' column to 'tokens'.
hf_ds = load_from_disk("/kaggle/input/cord-dataset/cord_dataset")
hf_ds = hf_ds.rename_columns({'words':'tokens'})

# Hugging Face checkpoint name used for both the config and the tokenizer.
model_name = "t5-base"
## Visual Embedding extractor's parameters
in_channels = 3
num_pool_layers = 3
channels = 16
sampling_ratio = 2
# NOTE(review): presumably (feature-map side) / (input image side) -- TODO confirm
spatial_scale = 48 / 384
output_size = (3,3)
# The two values below are not visual-extractor settings: `load_weights` is
# only stored on the config here, and `max_epochs` is read by the Trainer cell.
load_weights = True
max_epochs = 20

## CORD Dataset specific
# Matches the 61 entries returned by get_id2label_and_label2id().
num_classes = 61

## Tokenizer's parameter
model_max_length = 512

t5_config = AutoConfig.from_pretrained(model_name)
## Adding new parameters
# Everything the custom modules need is attached to the T5 config object, which
# is the single `config` argument passed to the model classes below. The
# learning rate (2e-4) is defined only here and read as `config.lr`.
t5_config.update(dict(in_channels = in_channels, num_pool_layers = num_pool_layers,  channels = channels, model_max_length = model_max_length,
                      output_size = output_size, spatial_scale = spatial_scale, sampling_ratio = sampling_ratio, use_cache = False, load_weights = load_weights,
                      lr =  2e-4, num_classes = num_classes,max_epochs = max_epochs))

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

In [8]:
def get_id2label_and_label2id():
    """Build the two lookup tables for the 61 token-classification tags.

    Class id 0 is the ``'O'`` tag, ids 1-30 are the ``B-`` tags and ids 31-60
    are the ``I-`` tags, listed in the same field order as the ``B-`` tags.

    Returns:
        tuple: ``(id2label, label2id)``. ``id2label`` maps each integer class
        id to its tag string; ``label2id`` is the inverse mapping.
    """
    # The position of a tag in this tuple IS its class id: do not reorder.
    ordered_tags = (
        'O',
        # ids 1-14
        'B-MENU.NM', 'B-MENU.NUM', 'B-MENU.UNITPRICE', 'B-MENU.CNT',
        'B-MENU.DISCOUNTPRICE', 'B-MENU.PRICE', 'B-MENU.ITEMSUBTOTAL',
        'B-MENU.VATYN', 'B-MENU.ETC', 'B-MENU.SUB_NM', 'B-MENU.SUB_UNITPRICE',
        'B-MENU.SUB_CNT', 'B-MENU.SUB_PRICE', 'B-MENU.SUB_ETC',
        # ids 15-16
        'B-VOID_MENU.NM', 'B-VOID_MENU.PRICE',
        # ids 17-22
        'B-SUB_TOTAL.SUBTOTAL_PRICE', 'B-SUB_TOTAL.DISCOUNT_PRICE',
        'B-SUB_TOTAL.SERVICE_PRICE', 'B-SUB_TOTAL.OTHERSVC_PRICE',
        'B-SUB_TOTAL.TAX_PRICE', 'B-SUB_TOTAL.ETC',
        # ids 23-30
        'B-TOTAL.TOTAL_PRICE', 'B-TOTAL.TOTAL_ETC', 'B-TOTAL.CASHPRICE',
        'B-TOTAL.CHANGEPRICE', 'B-TOTAL.CREDITCARDPRICE', 'B-TOTAL.EMONEYPRICE',
        'B-TOTAL.MENUTYPE_CNT', 'B-TOTAL.MENUQTY_CNT',
        # ids 31-44
        'I-MENU.NM', 'I-MENU.NUM', 'I-MENU.UNITPRICE', 'I-MENU.CNT',
        'I-MENU.DISCOUNTPRICE', 'I-MENU.PRICE', 'I-MENU.ITEMSUBTOTAL',
        'I-MENU.VATYN', 'I-MENU.ETC', 'I-MENU.SUB_NM', 'I-MENU.SUB_UNITPRICE',
        'I-MENU.SUB_CNT', 'I-MENU.SUB_PRICE', 'I-MENU.SUB_ETC',
        # ids 45-46
        'I-VOID_MENU.NM', 'I-VOID_MENU.PRICE',
        # ids 47-52
        'I-SUB_TOTAL.SUBTOTAL_PRICE', 'I-SUB_TOTAL.DISCOUNT_PRICE',
        'I-SUB_TOTAL.SERVICE_PRICE', 'I-SUB_TOTAL.OTHERSVC_PRICE',
        'I-SUB_TOTAL.TAX_PRICE', 'I-SUB_TOTAL.ETC',
        # ids 53-60
        'I-TOTAL.TOTAL_PRICE', 'I-TOTAL.TOTAL_ETC', 'I-TOTAL.CASHPRICE',
        'I-TOTAL.CHANGEPRICE', 'I-TOTAL.CREDITCARDPRICE', 'I-TOTAL.EMONEYPRICE',
        'I-TOTAL.MENUTYPE_CNT', 'I-TOTAL.MENUQTY_CNT',
    )

    # Both tables come from the same tuple, so they cannot drift apart.
    label2id = {tag: class_id for class_id, tag in enumerate(ordered_tags)}
    id2label = dict(enumerate(ordered_tags))
    return id2label, label2id

def convert_id_to_label(list_of_label, label_map=None):
  """Translate a sequence of integer class ids into their tag strings.

  Args:
    list_of_label: iterable of class ids.
    label_map: optional id -> tag mapping. When omitted, the notebook-level
      ``id2label`` global is used, which must already be defined at call time.

  Returns:
    list: the tag string for every id, in input order.

  Raises:
    KeyError: if an id is missing from the mapping.
  """
  # An explicit mapping removes the dependency on a global bound in a later cell.
  if label_map is None:
    label_map = id2label
  return [label_map[x] for x in list_of_label]

In [9]:
# train_new_tags = list(map(lambda x : convert_id_to_label(x), hf_ds['train']['ner_tags']))
# test_new_tags = list(map(lambda x : convert_id_to_label(x), hf_ds['test']['ner_tags']))

In [10]:
# hf_ds['train'] = hf_ds['train'].remove_columns("ner_tags").add_column("ner_tags", train_new_tags)
# hf_ds['test'] = hf_ds['test'].remove_columns("ner_tags").add_column("ner_tags", test_new_tags)

### 1.2 Writing the `collate_fn` for custom handling of the dataloader

In [11]:
class CollateFn(object):
  """Callable that turns a list of per-sample dicts into one batch dict.

  The tokenizer is stored on the instance but is not used while collating:
  every field, including ``labels``, is stacked as-is, so each sample must
  already provide it as a tensor of matching shape.
  """

  def __init__(self, tokenizer):
    self.tokenizer = tokenizer

  def __call__(self, list_of_ds):
    """Stack the five model inputs along a new leading batch dimension.

    Args:
      list_of_ds: list of dicts, each holding a tensor under every key below.

    Returns:
      dict: the same five keys, each mapped to the stacked tensor. Any other
      key present in the samples is dropped.
    """
    batched_fields = ("input_ids", "attention_mask", "bboxes", "pixel_values", "labels")
    return {
        field: torch.stack([sample[field] for sample in list_of_ds])
        for field in batched_fields
    }

In [12]:
# sample_batch_encoding = collate_fn([train_ds[0], train_ds[1]])
# for key in sample_batch_encoding:
#   sample_batch_encoding[key] = sample_batch_encoding[key].to(device)

## 2.1 Preparing the visual model

In [13]:
class VisualEmbedding(nn.Module):
  """Per-bounding-box visual features projected to the transformer width.

  Pipeline: U-Net encoder over the image -> RoI pooling of the feature map at
  each bounding box -> linear projection to ``config.d_model``.
  """

  def __init__(self, config):
    """Build the encoder, the RoI pooler and the projection from ``config``.

    Reads ``in_channels``, ``channels``, ``num_pool_layers``, ``output_size``,
    ``spatial_scale`` and ``d_model`` from ``config``.
    """
    super().__init__()
    self.unet_encoder = Unet_encoder(in_channels = config.in_channels, channels = config.channels, num_pool_layers = config.num_pool_layers)
    # NOTE(review): `config.sampling_ratio` is set in the notebook config but is
    # not forwarded to RoIPool here -- confirm whether that is intended.
    self.roi_pool = RoIPool(output_size = config.output_size, spatial_scale = config.spatial_scale)
    # NOTE(review): 128 * 3 * 3 is hard-coded; presumably (encoder output
    # channels) * output_size[0] * output_size[1] for the current settings.
    # Changing channels / num_pool_layers / output_size would likely require
    # updating it -- TODO confirm against Unet_encoder.
    self.proj = nn.Linear(in_features = 128 * 3 * 3, out_features = config.d_model)
    self.config = config

  def forward(self, pixel_values, bboxes):
    """Return the projected RoI features for ``bboxes`` on ``pixel_values``.

    Tensor shapes are not visible from this block; the RoI output is flattened
    from dimension 2 onward before the projection.
    """
    image_embedding = self.unet_encoder(pixel_values)
    # Collapse everything after the first two dimensions into one feature axis
    # so it can be fed to the linear projection.
    feature_maps_bboxes = self.roi_pool(image_embedding, bboxes).flatten(2)
    projection = self.proj(feature_maps_bboxes)
    return projection

## 2.2 Preparing the semantic model

In [14]:
class TiLTTransformer(nn.Module):
  """T5-based model whose input embedding is token embedding + visual embedding.

  ``config`` is handed unchanged to ``VisualEmbedding`` and to
  ``T5ForConditionalGenerationAbstractive``.
  """

  def __init__(self, config):
    super().__init__()
    self.config = config
    self.visual_embedding_extractor = VisualEmbedding(config)
    self.t5_model = T5ForConditionalGenerationAbstractive(config)

  def generate(self, batch):
    """Run ``t5_model.generate`` on the fused embeddings of ``batch``.

    Args:
      batch: dict with ``pixel_values``, ``bboxes`` and ``input_ids``; an
        ``attention_mask`` entry is forwarded when present.

    Returns:
      Whatever ``t5_model.generate`` returns.
    """
    total_embedding = self.common_step(batch)

    # The keyword is `inputs_embeds` (as in `forward`); the previous spelling
    # `input_embeds` is not the argument name used anywhere else.
    generate_kwargs = {'inputs_embeds': total_embedding}
    # Forward the mask as well so generation sees the same mask `forward` uses.
    if 'attention_mask' in batch:
      generate_kwargs['attention_mask'] = batch['attention_mask']
    return self.t5_model.generate(**generate_kwargs)

  def common_step(self, batch):
    """Return the element-wise sum of the visual and the token embeddings."""
    ## Visual embedding
    visual_embedding = self.visual_embedding_extractor(pixel_values = batch['pixel_values'], bboxes = batch['bboxes'])

    ## Semantic embedding from t5_model's embedding layer
    semantic_embedding = self.t5_model.shared(batch['input_ids'])

    ## Net embedding is addition of both the embeddings
    total_embedding = visual_embedding + semantic_embedding

    return total_embedding

  def forward(self, batch):
    """Feed the fused embeddings, attention mask and labels to ``t5_model``.

    Args:
      batch: dict with ``pixel_values``, ``bboxes``, ``input_ids``,
        ``attention_mask`` and ``labels``.

    Returns:
      The output object of ``t5_model`` (callers in this notebook read
      ``.loss`` and ``.logits`` from it).
    """
    total_embedding = self.common_step(batch)

    ## This is then fed to t5_model
    final_output = self.t5_model(attention_mask = batch['attention_mask'], inputs_embeds = total_embedding,
                            labels = batch['labels'])

    return final_output

In [15]:
# tilt_model = TiLTTransformer(t5_config).to(device)
# output = tilt_model(sample_batch_encoding)

## 3.1 Preparing the metrics to evaluate the predictions

In [16]:
import evaluate

def get_labels(predictions, references, label_map=None):
    """Convert predicted and gold class ids to tag strings, skipping ignored positions.

    Args:
        predictions: 2-D integer tensor of predicted class ids, one row per sequence.
        references: tensor of gold class ids with the same layout; positions
            holding ``-100`` are ignored.
        label_map: optional id -> tag mapping. When omitted, the notebook-level
            ``id2label`` global is used, which must be defined at call time.

    Returns:
        tuple: ``(true_predictions, true_labels)``, two lists with one list of
        tag strings per sequence, restricted to positions whose gold id is not
        ``-100``.

    Raises:
        KeyError: if a kept id is missing from the mapping.
    """
    if label_map is None:
        label_map = id2label

    # `.cpu()` is a no-op for tensors already on the CPU, so each tensor is
    # moved independently and no device-specific branch is needed.
    y_pred = predictions.detach().cpu().tolist()
    y_true = references.detach().cpu().tolist()

    true_predictions, true_labels = [], []
    for pred_row, gold_row in zip(y_pred, y_true):
        # Remove ignored index (special tokens); filter once, use for both outputs.
        kept = [(p, l) for p, l in zip(pred_row, gold_row) if l != -100]
        true_predictions.append([label_map[p] for p, _ in kept])
        true_labels.append([label_map[l] for _, l in kept])
    return true_predictions, true_labels

In [17]:
# labels = sample_batch_encoding['labels']

In [18]:
# true_predictions, true_labels = get_labels(predictions = output.logits.argmax(axis = -1), references = labels)

In [19]:
# eval_metric = evaluate.load("seqeval")
# metric = eval_metric.compute(predictions = true_predictions, references = true_labels)

## Part 4: Writing the `pytorch_lightning` code for training on CORD

In [20]:
# Module-level lookup tables; `get_labels` and `convert_id_to_label` read
# `id2label` as a global, so this cell must run before they are called.
id2label, label2id = get_id2label_and_label2id()
# ToTensor followed by x -> 2 * x - 1 (maps a [0, 1] input range to [-1, 1]).
transform = transforms.Compose([transforms.ToTensor(), 
                                transforms.Lambda(lambda x : 2 * x - 1)])

## Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast = True, model_max_length = model_max_length)

# The dataset's 'test' split is used as the validation set.
train_ds = ExtFUNSDDs(hf_ds['train'],tokenizer = tokenizer, transform = transform)
val_ds = ExtFUNSDDs(hf_ds['test'],tokenizer = tokenizer, transform = transform)

# Module-level collate function; `DataModule` reads it as a global.
collate_fn = CollateFn(tokenizer)

Downloading (…)ve/main/spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

In [21]:
class DataModule(pl.LightningDataModule):
  """Lightning data module serving the training and evaluation datasets.

  Both loaders use the notebook-level ``collate_fn`` and the same batch size;
  only the training loader shuffles.
  """

  def __init__(self, train_dataset, eval_dataset, batch_size: int = 2):
    super().__init__()
    self.train_dataset = train_dataset
    self.eval_dataset = eval_dataset
    self.batch_size = batch_size

  def _build_loader(self, dataset, shuffle):
    # Single place where the DataLoader options are spelled out.
    return DataLoader(dataset, batch_size=self.batch_size,
                      shuffle=shuffle, collate_fn = collate_fn)

  def train_dataloader(self):
    """Shuffled loader over the training dataset."""
    return self._build_loader(self.train_dataset, shuffle=True)

  def val_dataloader(self):
    """Loader over the evaluation dataset, in dataset order."""
    return self._build_loader(self.eval_dataset, shuffle=False)

In [22]:
class TiltModel(pl.LightningModule):
    """Lightning wrapper around ``TiLTTransformer`` that logs loss and seqeval scores.

    The training and validation steps run the same forward pass, convert the
    arg-max predictions to tag strings with ``get_labels`` and log loss, F1,
    recall and precision under a ``train_`` or ``val_`` prefix.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.tilt_model = TiLTTransformer(config)
        self.eval_metric = evaluate.load("seqeval")

    def forward(self, batch):
        """Delegate to the wrapped model and return its output unchanged."""
        return self.tilt_model(batch)

    def _step_and_log(self, batch, prefix):
        """Forward ``batch``, log the four ``<prefix>_*`` values, return the loss."""
        model_output = self(batch)
        step_loss = model_output.loss

        predicted_ids = model_output.logits.argmax(dim=-1)
        pred_tags, gold_tags = get_labels(predictions = predicted_ids, references = batch['labels'])
        scores = self.eval_metric.compute(predictions=pred_tags, references=gold_tags)

        # NOTE: the F1 key ends in "fl" (letter l, not digit 1). The checkpoint
        # callback in this notebook monitors "val_overall_fl", so the spelling
        # must stay in sync with it.
        to_log = [
            (f"{prefix}_loss", step_loss.item()),
            (f"{prefix}_overall_fl", scores["overall_f1"]),
            (f"{prefix}_overall_recall", scores["overall_recall"]),
            (f"{prefix}_overall_precision", scores["overall_precision"]),
        ]
        for metric_name, metric_value in to_log:
            self.log(metric_name, metric_value,
                     prog_bar=True, on_epoch=True, logger=True)

        return step_loss

    def training_step(self, batch, batch_idx):
        """Log the ``train_*`` values and return the loss for backpropagation."""
        return self._step_and_log(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Log the ``val_*`` values and return the loss."""
        return self._step_and_log(batch, "val")

    def configure_optimizers(self):
        """AdamW over all parameters with the learning rate stored in ``config.lr``."""
        return torch.optim.AdamW(self.parameters(), lr=self.config.lr)

In [23]:
def perform_evaluation(path: str=None, pl_model=None, pl_dl=None):
    """Evaluate a model on the validation loader and report overall seqeval scores.

    Args:
        path: optional checkpoint path. When given, a fresh model of the same
            class as ``pl_model`` is loaded from it (using the notebook-level
            ``t5_config``) and evaluated instead of ``pl_model``.
        pl_model: Lightning module to evaluate; also supplies the class used
            for checkpoint loading, so it is required even when ``path`` is set.
        pl_dl: data module whose ``val_dataloader()`` provides the batches.

    Returns:
        dict: ``overall_precision``, ``overall_recall``, ``overall_f1`` and
        ``overall_accuracy`` computed over the whole validation loader. The
        same four values are also printed.
    """
    print("Evaluating the model")
    if path is not None:
        # `load_from_checkpoint` is a classmethod: call it on the class rather
        # than on the instance (newer Lightning releases reject the latter).
        pl_model = type(pl_model).load_from_checkpoint(path, config = t5_config)

    eval_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    eval_metric = evaluate.load("seqeval")
    pl_model = pl_model.to(eval_device)
    pl_model.eval()

    with torch.no_grad():
        for batch in tqdm(pl_dl.val_dataloader()):
            # move batch to device
            batch = {k: v.to(eval_device) for k, v in batch.items()}

            outputs = pl_model(batch)
            predictions = outputs.logits.argmax(-1)
            true_predictions, true_labels = get_labels(predictions, batch["labels"])
            # Accumulate per batch; the scores are computed once at the end.
            eval_metric.add_batch(references=true_labels, predictions=true_predictions)

    results = eval_metric.compute()

    metrics = {}
    for key in ['overall_precision', 'overall_recall', 'overall_f1', 'overall_accuracy']:
        print_statement = '{0: <30}'.format(str(key) + " has value:")
        print(print_statement, results[key])
        metrics[key] = results[key]

    return metrics

In [24]:
# Keep only the single best checkpoint, ranked by the highest "val_overall_fl".
# That key (spelled with the letter l) is the name logged in
# TiltModel.validation_step, so the two must stay in sync.
checkpoint_callback = ModelCheckpoint(dirpath="./tilt/models",monitor="val_overall_fl", mode="max", filename='tilt_best_ckpt', save_top_k = 1)
# NOTE(review): this CSV logger is created but the Trainer below is given
# `wandb_logger` instead; its run name still says "funsd_dataset".
logger = CSVLogger("./tilt/logs", name="funsd_dataset")

In [25]:
# Start the W&B run first; per the warning in this cell's output, the
# WandbLogger created afterwards reuses the run that is already in progress.
wandb.init(config=t5_config, project="TiLT on CORD")
# NOTE(review): the entity (account name) is hard-coded to the author's account.
wandb_logger = WandbLogger(project="TiLT on CORD", log_model = False, entity="iakarshu")

[34m[1mwandb[0m: Currently logged in as: [33miakarshu[0m. Use [1m`wandb login --relogin`[0m to force relogin


  "There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse"


In [26]:
# The epoch budget comes from the shared config; metrics go to W&B and the
# checkpoint callback saves the best model under ./tilt/models.
trainer = pl.Trainer(default_root_dir="./tilt/logs", devices="auto",
                     accelerator="auto",max_epochs= t5_config.max_epochs, 
                     logger=wandb_logger,callbacks=[checkpoint_callback])

In [27]:
# Instantiate the Lightning module and the data module (batch size 2).
pl_model = TiltModel(t5_config)
pl_dl = DataModule(train_ds, val_ds, batch_size = 2)

Downloading pytorch_model.bin:   0%|          | 0.00/892M [00:00<?, ?B/s]

Weights loaded successfully!


Downloading builder script:   0%|          | 0.00/6.34k [00:00<?, ?B/s]

In [28]:
# Train for `t5_config.max_epochs` epochs (set on the Trainer above).
trainer.fit(pl_model, pl_dl)

Sanity Checking: 0it [00:00, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [29]:
ckpt_folder = "./tilt/models"
# Pick the most recently written checkpoint. `os.listdir` returns entries in
# arbitrary order and the folder may be empty or hold other files, so taking
# index 0 is neither deterministic nor safe.
ckpt_path = None
if os.path.isdir(ckpt_folder):
  ckpt_files = [os.path.join(ckpt_folder, name)
                for name in os.listdir(ckpt_folder) if name.endswith(".ckpt")]
  if ckpt_files:
    ckpt_path = max(ckpt_files, key=os.path.getmtime)

In [30]:
# Evaluate the saved checkpoint on the validation loader; when `ckpt_path` is
# None, `perform_evaluation` evaluates `pl_model` as it is in memory.
metrics = perform_evaluation(path = ckpt_path, pl_model = pl_model, pl_dl = pl_dl)

Evaluating the model
Weights loaded successfully!


  0%|          | 0/50 [00:00<?, ?it/s]

overall_precision has value:   0.6480859637340497
overall_recall has value:      0.6264199935086011
overall_f1 has value:          0.637068823238158
overall_accuracy has value:    0.8052400146573837


In [31]:
# Show the four overall seqeval scores returned by perform_evaluation.
print(metrics)

{'overall_precision': 0.6480859637340497, 'overall_recall': 0.6264199935086011, 'overall_f1': 0.637068823238158, 'overall_accuracy': 0.8052400146573837}
