In [1]:
!git clone https://github.com/uakarsh/latr.git

Cloning into 'latr'...
remote: Enumerating objects: 301, done.[K
remote: Counting objects: 100% (137/137), done.[K
remote: Compressing objects: 100% (65/65), done.[K
remote: Total 301 (delta 92), reused 97 (delta 70), pack-reused 164[K
Receiving objects: 100% (301/301), 4.79 MiB | 14.04 MiB/s, done.
Resolving deltas: 100% (121/121), done.


In [2]:
!pip -qqq install -r ./latr/requirements.txt

In [3]:
!sudo apt install -qqq tesseract-ocr

In [4]:
## Default Library import

import os
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import json
from tqdm.auto import tqdm
import pandas as pd

from transformers import AutoTokenizer, AutoConfig, AutoProcessor
from transformers import T5ForConditionalGeneration, ViTModel
import torch.nn as nn
import torch

from torch.utils.data import DataLoader

## Setting up the device for GPU usage
from torch import cuda
device = 'cuda' if cuda.is_available() else 'cpu'
import pytorch_lightning as pl

In [5]:
import sys
sys.path.append("./latr/src/new_latr/")

from dataset import TextVQA
from utils import collate, draw_bounding_box_on_pil_image

In [6]:
## Hyperparameters and primary configuration

# Name of the T5 checkpoint
t5_model = "t5-base"

# Image size as a 2-tuple and number of samples per batch
target_size = (224, 224)
batch_size = 2

# All-zero boxes; going by their names they stand for padding, question and
# end-of-sequence tokens. Each constant is its own list object.
PAD_TOKEN_BOX = [0] * 4
QUESTION_BOX = [0] * 4
EOS_BOX = [0] * 4

In [7]:
# Checkpoint names for the text and vision backbones, plus the size of the
# 2D position embedding tables.
model_name = 't5-base'
vit_model = "google/vit-base-patch16-224-in21k"
max_2d_position_embeddings = 1024

# Start from the stock T5 config and attach the two extra fields that the
# model classes defined later in this notebook read from it.
model_config = AutoConfig.from_pretrained(model_name)
model_config.update(dict(
    max_2d_position_embeddings = max_2d_position_embeddings,
    vit_model = vit_model,
))

# Text tokenizer for the T5 checkpoint, processor for the ViT checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast = True)
processor = AutoProcessor.from_pretrained(vit_model)

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

Downloading (…)ve/main/spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


Downloading (…)rocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

In [8]:
# Root folder of the attached dataset
base_path = '/kaggle/input/new-textvqa-dataset-mine'

# Training split: OCR file and question/answer annotation file
ocr_json_path, train_json_path = (
    os.path.join(base_path, file_name)
    for file_name in ('TextVQA_Rosetta_OCR_v0.2_train.json', 'TextVQA_0.5.1_train.json')
)

# Validation split: OCR file and question/answer annotation file
val_ocr_json_path, val_json_path = (
    os.path.join(base_path, file_name)
    for file_name in ('TextVQA_Rosetta_OCR_v0.2_val.json', 'TextVQA_0.5.1_val.json')
)

In [9]:
def _read_data_field(json_path):
    """Parse the JSON file at ``json_path`` and return the value under its top-level 'data' key."""
    with open(json_path) as json_file:
        return json.load(json_file)['data']


## Training
train_ocr_json = _read_data_field(ocr_json_path)
train_json = _read_data_field(train_json_path)

## Validation
val_ocr_json = _read_data_field(val_ocr_json_path)
val_json = _read_data_field(val_json_path)

In [10]:
## DataFrame views of the loaded JSON records (useful for the key-value extraction)

# Training split: question/answer records and OCR records
train_json_df, train_ocr_json_df = map(pd.DataFrame, (train_json, train_ocr_json))

# Validation split: question/answer records and OCR records
val_json_df, val_ocr_json_df = map(pd.DataFrame, (val_json, val_ocr_json))

In [11]:
## Remove the same set of columns, in place, from both question/answer frames.
## ('path_exists' stays commented out of the list, exactly as before.)
for qa_frame in (train_json_df, val_json_df):
    qa_frame.drop(
        columns = ['flickr_original_url', 'flickr_300k_url', 'image_classes', 'question_tokens'],  # 'path_exists'
        axis = 1,
        inplace = True,
    )
del qa_frame  # do not leave the loop variable behind in the notebook namespace

## Deleting the json
del train_json, train_ocr_json, val_json, val_ocr_json

In [12]:
# Image directory; this same path is passed to both the training and the
# validation TextVQA dataset constructed below.
base_img_path = os.path.join(base_path, 'train_val_images', 'train_images')

In [13]:
# Forwarded as `max_seq_length` to both TextVQA datasets.
# NOTE(review): -1 presumably means "no fixed maximum length" — confirm in dataset.py.
max_seq_len = -1

In [14]:
def _build_textvqa(json_df, ocr_json_df):
    """Create a TextVQA dataset for one split.

    The image directory, tokenizer, processor and maximum sequence length are
    the notebook-level values shared by both splits.
    """
    return TextVQA(
        base_img_path = base_img_path,
        json_df = json_df,
        ocr_json_df = ocr_json_df,
        tokenizer = tokenizer,
        transform = processor,
        max_seq_length = max_seq_len,
    )


train_ds = _build_textvqa(train_json_df, train_ocr_json_df)
val_ds = _build_textvqa(val_json_df, val_ocr_json_df)

In [15]:
# encoding = train_ds[500]
# print(tokenizer.decode(encoding['input_ids'], skip_special_tokens = True))
# print(tokenizer.decode(encoding['labels'], skip_special_tokens = True))

In [16]:
# from torchvision.transforms import ToPILImage
# pil_image = ToPILImage()(encoding['pixel_values']).resize((1000, 1000))
# visualized_pil_image = draw_bounding_box_on_pil_image(pil_image, encoding['bbox'], outline = 'red')

In [17]:
# first_sample = train_ds[22]
# second_sample = train_ds[25]

# batch_encoding = collate([first_sample, second_sample])

# for key in batch_encoding:
#     print(f"Key : {key}, has shape {batch_encoding[key].shape}")

In [18]:
class DataModule(pl.LightningDataModule):
    """Lightning data module wrapping a training and a validation dataset.

    Args:
        train_dataset: dataset served by ``train_dataloader`` (shuffled).
        val_dataset: dataset served by ``val_dataloader`` (not shuffled).
        batch_size: number of samples per batch for both loaders.
    """

    def __init__(self, train_dataset, val_dataset,  batch_size = 2):
        super().__init__()
        self.batch_size = batch_size
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader over ``dataset`` using the shared batch size and ``collate``."""
        return DataLoader(
            dataset,
            batch_size = self.batch_size,
            collate_fn = collate,
            shuffle = shuffle,
        )

    def train_dataloader(self):
        """Return the shuffled loader over the training dataset."""
        return self._make_loader(self.train_dataset, True)

    def val_dataloader(self):
        """Return the unshuffled loader over the validation dataset."""
        return self._make_loader(self.val_dataset, False)


In [19]:
# Pass the configured batch size explicitly; otherwise DataModule silently
# falls back to its own default and editing `batch_size` has no effect.
dl = DataModule(train_ds, val_ds, batch_size = batch_size)

In [20]:
# Draw the first batch from the training loader (which shuffles) for inspection.
sample = next(iter(dl.train_dataloader()))

In [21]:
# Report the shape of every entry in the sample batch
for key in sample:
    entry = sample[key]
    print(f"Key : {key}, has shape : {entry.shape}")
del entry  # keep only `key` behind, as the plain loop did

Key : img, has shape : torch.Size([2, 3, 224, 224])
Key : bbox, has shape : torch.Size([2, 124, 6])
Key : input_ids, has shape : torch.Size([2, 124])
Key : labels, has shape : torch.Size([2, 3])
Key : attention_mask, has shape : torch.Size([2, 124])


In [22]:
class SpatialModule(nn.Module):
    """Sum of six learned embeddings, one per entry of the last axis of a box tensor.

    Each of the six embedding tables has ``config.max_2d_position_embeddings``
    rows and ``config.d_model`` columns.
    """

    # Attribute names in the order the tables are created. The order is kept
    # so that parameter registration and seeded initialisation stay the same.
    _TABLE_NAMES = (
        "top_left_x",
        "bottom_right_x",
        "top_left_y",
        "bottom_right_y",
        "width_emb",
        "height_emb",
    )

    def __init__(self, config):
        super().__init__()
        num_positions, hidden = config.max_2d_position_embeddings, config.d_model
        for table_name in self._TABLE_NAMES:
            setattr(self, table_name, nn.Embedding(num_positions, hidden))

    def forward(self, coordinates):
        """Embed ``coordinates[:, :, 0..5]`` and return the element-wise sum.

        Args:
            coordinates: integer index tensor, indexed as ``[:, :, k]`` for k in 0..5.

        Returns:
            Tensor with the two leading dimensions of ``coordinates`` followed
            by ``config.d_model``.
        """
        # Table applied to channel k of the last axis, in summation order
        per_channel_tables = (
            self.top_left_x,
            self.top_left_y,
            self.bottom_right_x,
            self.bottom_right_y,
            self.width_emb,
            self.height_emb,
        )

        layout_feature = None
        for channel, table in enumerate(per_channel_tables):
            channel_feat = table(coordinates[:, :, channel])
            layout_feature = channel_feat if layout_feature is None else layout_feature + channel_feat
        return layout_feature

class LaTrForConditionalGeneration(nn.Module):
    """T5 conditioned on ViT image tokens and layout-augmented text tokens.

    The constructor loads a pretrained ViT (``config.vit_model``) and a
    pretrained T5 (``config._name_or_path``), and creates a ``SpatialModule``
    for the bounding-box embeddings.
    """

    def __init__(self, config):
        super().__init__()
        self.spatial_feat_extractor = SpatialModule(config)
        self.img_feat_extractor = ViTModel.from_pretrained(config.vit_model)
        self.t5_model = T5ForConditionalGeneration.from_pretrained(config._name_or_path)

    def forward(self, batch):
        """Run T5 on the concatenation of image tokens and text+layout tokens.

        Args:
            batch: mapping with the keys ``'img'``, ``'bbox'``, ``'input_ids'``
                and ``'labels'``. If it also contains ``'attention_mask'`` (the
                mask over the text tokens), that mask is extended to cover the
                image tokens and forwarded to T5.

        Returns:
            Whatever ``self.t5_model`` returns for the given embeddings and labels.
        """
        img_feat = self.img_feat_extractor(batch['img']).last_hidden_state
        spatial_feat = self.spatial_feat_extractor(batch['bbox'])
        language_feat = self.t5_model.shared(batch['input_ids'])

        # Text embeddings carry their layout embedding; image tokens are
        # prepended along the sequence axis (the hidden sizes must match).
        layout_feat = spatial_feat + language_feat
        multi_modal_feat = torch.cat([img_feat, layout_feat], dim = 1)

        # Without a mask T5 attends to every position, including padded text
        # tokens. Image tokens are never padding, so they get an all-ones
        # mask placed in front of the text mask.
        attention_mask = None
        if 'attention_mask' in batch:
            text_mask = batch['attention_mask']
            img_mask = torch.ones(img_feat.shape[:2], dtype = text_mask.dtype, device = text_mask.device)
            attention_mask = torch.cat([img_mask, text_mask], dim = 1)

        return self.t5_model(inputs_embeds = multi_modal_feat,
                             attention_mask = attention_mask,
                             labels = batch['labels'],
                             output_hidden_states = False)

In [23]:
# Build the model; the constructor loads the pretrained ViT and T5 checkpoints
# named in model_config.
latr_model = LaTrForConditionalGeneration(model_config)

Downloading pytorch_model.bin:   0%|          | 0.00/346M [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/892M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

In [24]:
# Forward pass on the sample batch drawn earlier (not wrapped in torch.no_grad()).
output = latr_model(sample)

## Code for writing the generate function, since we need to predict the answers now. 