In [1]:
!git clone https://github.com/uakarsh/latr.git

Cloning into 'latr'...
remote: Enumerating objects: 331, done.[K
remote: Counting objects: 100% (167/167), done.[K
remote: Compressing objects: 100% (88/88), done.[K
remote: Total 331 (delta 107), reused 116 (delta 77), pack-reused 164[K
Receiving objects: 100% (331/331), 4.80 MiB | 2.92 MiB/s, done.
Resolving deltas: 100% (136/136), done.


In [2]:
!pip -qqq install -r ./latr/requirements.txt

In [3]:
!sudo apt install -qqq tesseract-ocr

In [4]:
## Default Library import

import os
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

import json
from tqdm.auto import tqdm
import pandas as pd

from transformers import AutoTokenizer, AutoConfig, AutoProcessor
from transformers import T5ForConditionalGeneration, ViTModel
import torch.nn as nn
import torch

from torch.utils.data import DataLoader

## Setting up the device for GPU usage
from torch import cuda
device = 'cuda' if cuda.is_available() else 'cpu'
import pytorch_lightning as pl

In [5]:
import sys
sys.path.append("./latr/src/new_latr/")

from dataset import TextVQA
from utils import collate, draw_bounding_box_on_pil_image

In [6]:
## Setting the hyperparameters as well as primary configurations

PAD_TOKEN_BOX = [0, 0, 0, 0]
QUESTION_BOX = [0, 0, 0, 0]
EOS_BOX = [0, 0, 0, 0]

batch_size = 2
target_size = (224,224)
t5_model = "t5-base"

In [7]:
model_name = 't5-base'
model_config = AutoConfig.from_pretrained(model_name)

max_2d_position_embeddings = 1024
vit_model = "google/vit-base-patch16-224-in21k"
model_config.update({"max_2d_position_embeddings" : max_2d_position_embeddings,
                    "vit_model" : vit_model})

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast = True)
processor = AutoProcessor.from_pretrained(vit_model)

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

Downloading (…)ve/main/spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


Downloading (…)rocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

In [8]:
base_path = '/kaggle/input/new-textvqa-dataset-mine'
ocr_json_path = os.path.join(base_path, 'TextVQA_Rosetta_OCR_v0.2_train.json')
train_json_path = os.path.join(base_path, 'TextVQA_0.5.1_train.json')

val_ocr_json_path = os.path.join(base_path, 'TextVQA_Rosetta_OCR_v0.2_val.json')
val_json_path = os.path.join(base_path, 'TextVQA_0.5.1_val.json')

In [9]:
with open(ocr_json_path) as f:
    train_ocr_json = json.load(f)['data']
with open(train_json_path) as f:
    train_json = json.load(f)['data']
    
## Validation
with open(val_ocr_json_path) as f:
    val_ocr_json = json.load(f)['data']
with open(val_json_path) as f:
    val_json = json.load(f)['data']

In [10]:
## Useful for the key-value extraction

train_json_df = pd.DataFrame(train_json)
train_ocr_json_df = pd.DataFrame(train_ocr_json)

val_json_df = pd.DataFrame(val_json)
val_ocr_json_df = pd.DataFrame(val_ocr_json)

In [11]:
train_json_df.drop(columns = ['flickr_original_url', 'flickr_300k_url','image_classes', 'question_tokens',# 'path_exists'
                              ], axis = 1, inplace = True)

val_json_df.drop(columns = ['flickr_original_url', 'flickr_300k_url','image_classes', 'question_tokens',# 'path_exists'
                              ], axis = 1, inplace = True)
## Deleting the json

del train_json
del train_ocr_json
del val_json
del val_ocr_json

In [12]:
base_img_path = os.path.join(base_path, 'train_val_images', 'train_images')

In [13]:
max_seq_len = -1

In [14]:
train_ds = TextVQA(base_img_path = base_img_path,
                   json_df = train_json_df,
                   ocr_json_df = train_ocr_json_df,
                   tokenizer = tokenizer,
                   transform = processor, 
                   max_seq_length = max_seq_len, 
                   )

val_ds = TextVQA(base_img_path = base_img_path,
                   json_df = val_json_df,
                   ocr_json_df = val_ocr_json_df,
                   tokenizer = tokenizer,
                   transform = processor, 
                   max_seq_length = max_seq_len, 
                   )

In [15]:
# encoding = train_ds[500]
# print(tokenizer.decode(encoding['input_ids'], skip_special_tokens = True))
# print(tokenizer.decode(encoding['labels'], skip_special_tokens = True))

In [16]:
# from torchvision.transforms import ToPILImage
# pil_image = ToPILImage()(encoding['pixel_values']).resize((1000, 1000))
# visualized_pil_image = draw_bounding_box_on_pil_image(pil_image, encoding['bbox'], outline = 'red')

In [17]:
# first_sample = train_ds[22]
# second_sample = train_ds[25]

# batch_encoding = collate([first_sample, second_sample])

# for key in batch_encoding:
#     print(f"Key : {key}, has shape {batch_encoding[key].shape}")

In [18]:
class DataModule(pl.LightningDataModule):

  def __init__(self, train_dataset, val_dataset,  batch_size = 1):

    super(DataModule, self).__init__()
    self.train_dataset = train_dataset
    self.val_dataset = val_dataset
    self.batch_size = batch_size

  def train_dataloader(self):
    return DataLoader(self.train_dataset, batch_size = self.batch_size, 
                      collate_fn = collate, shuffle = True)
  
  def val_dataloader(self):
    return DataLoader(self.val_dataset, batch_size = self.batch_size,
                                  collate_fn = collate, shuffle = False)


In [19]:
dl = DataModule(train_ds, val_ds)

In [20]:
sample = next(iter(dl.train_dataloader()))

In [21]:
for key in sample:
    print(f"Key : {key}, has shape : {sample[key].shape}")
    sample[key] = sample[key].to(device)

Key : img, has shape : torch.Size([1, 3, 224, 224])
Key : bbox, has shape : torch.Size([1, 17, 6])
Key : input_ids, has shape : torch.Size([1, 17])
Key : labels, has shape : torch.Size([1, 4])
Key : attention_mask, has shape : torch.Size([1, 17])


In [22]:
from transformers.modeling_outputs import (
    BaseModelOutput,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
)
import torch.nn as nn
from torch.nn import CrossEntropyLoss

class SpatialModule(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.top_left_x = nn.Embedding(config.max_2d_position_embeddings, config.d_model)
        self.bottom_right_x = nn.Embedding(config.max_2d_position_embeddings, config.d_model)
        self.top_left_y = nn.Embedding(config.max_2d_position_embeddings, config.d_model)
        self.bottom_right_y = nn.Embedding(config.max_2d_position_embeddings, config.d_model)
        self.width_emb = nn.Embedding(config.max_2d_position_embeddings, config.d_model)
        self.height_emb = nn.Embedding(config.max_2d_position_embeddings, config.d_model)
        
    def forward(self, coordinates):
        
        top_left_x_feat =     self.top_left_x(coordinates[:,:, 0])
        top_left_y_feat =     self.top_left_y(coordinates[:,:, 1])
        bottom_right_x_feat = self.bottom_right_x(coordinates[:,:, 2])
        bottom_right_y_feat = self.bottom_right_y(coordinates[:,:, 3])
        width_feat =          self.width_emb(coordinates[:,:, 4])
        height_feat =         self.height_emb(coordinates[:,:, 5])
        
        layout_feature = top_left_x_feat + top_left_y_feat + bottom_right_x_feat + bottom_right_y_feat + width_feat + height_feat
        return layout_feature

class LaTrForConditionalGeneration(T5ForConditionalGeneration):
    def __init__(self, config):
        super().__init__(config = config)
        self.spatial_feat_extractor = SpatialModule(config)
        self.img_feat_extractor = ViTModel.from_pretrained(config.vit_model)
#         self.t5_model = T5ForConditionalGeneration.from_pretrained(config._name_or_path)
        
    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_output_embeddings(self):
        return self.lm_head

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder
    
    def forward(
        self,
        input_ids = None,
        bbox = None,
        attention_mask = None,
        decoder_input_ids = None,
        decoder_attention_mask = None,
        encoder_outputs = None,
        past_key_values = None,
        pixel_values = None,
        visual_bbox = None,
        labels = None,
        head_mask = None,
        inputs_embeds = None,
        decoder_inputs_embeds = None,
        decoder_head_mask = None,
        cross_attn_head_mask = None,
        use_cache=True,
        output_attentions = None,
        output_hidden_states = None,
        return_dict = None,
        **kwargs,) :

        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if decoder_input_ids is None and labels is not None:
            decoder_input_ids = self._shift_right(labels)

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            inputs_embeds, attention_mask = self.calculate_embedding(pixel_values, bbox, input_ids, attention_mask)
            encoder_outputs = self.encoder(
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        hidden_states = encoder_outputs[0]

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.config.d_model**-0.5)

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))

        if not return_dict:
            output = (lm_logits,) + decoder_outputs[2:] + (encoder_outputs[0],) + encoder_outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        # cut decoder_input_ids if past is used
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]

        return {
            "decoder_input_ids": input_ids,
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
            "bbox": kwargs.get("bbox", None),
            "pixel_values": kwargs.get("pixel_values", None),
            "visual_bbox": kwargs.get("visual_bbox", None),
        }

    def _reorder_cache(self, past_key_values, beam_idx):
        # if decoder past is not included in output
        # speedy decoding is disabled and no need to reorder
        if past_key_values is None:
            logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
            return past_key_values

        reordered_decoder_past = ()
        for layer_past_states in past_key_values:
            # get the correct batch idx from layer past batch dim
            # batch dim of `past` is at 2nd position
            reordered_layer_past_states = ()
            for layer_past_state in layer_past_states:
                # need to set correct `past` for each of the four key / value states
                reordered_layer_past_states = reordered_layer_past_states + (
                    layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
                )

            assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
            assert len(reordered_layer_past_states) == len(layer_past_states)

            reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
        return reordered_decoder_past
    
    def calculate_embedding(self, img, bbox, input_ids, attention_mask):
        img_feat = self.img_feat_extractor(img).last_hidden_state
        spatial_feat = self.spatial_feat_extractor(bbox)
        language_feat = self.shared(input_ids)
        
        layout_feat = spatial_feat + language_feat
        multi_modal_feat = torch.cat([img_feat, layout_feat], axis = 1)
        input_attention_mask = torch.cat([torch.ones(img_feat.shape[:2]).to(img_feat.device), attention_mask], axis = 1)
        return multi_modal_feat, input_attention_mask

In [23]:
latr_model = LaTrForConditionalGeneration(model_config).to(device)

Downloading pytorch_model.bin:   0%|          | 0.00/346M [00:00<?, ?B/s]

In [24]:
output = latr_model(input_ids = sample['input_ids'],
                   bbox = sample['bbox'], pixel_values = sample['img'], labels = sample['labels'],
                   attention_mask = sample['attention_mask'])

In [25]:
tokenizer.decode(latr_model.generate(input_ids = sample['input_ids'],
                   bbox = sample['bbox'], pixel_values = sample['img'], labels = sample['labels'],
                   attention_mask = sample['attention_mask'])[0], skip_special_tokens = True)



''

In [26]:
tokenizer.decode(sample['input_ids'][0], skip_special_tokens = True)

'question: what color are the letters on her hat? context: EC'

In [27]:
# !git clone https://github.com/yashkant/sam-textvqa.git

In [28]:
# !pip install -r ./sam-textvqa/requirements.txt

In [29]:
## Ref: https://github.com/yashkant/sam-textvqa/blob/main/sam/datasets/metrics.py#L305

import re

class EvalAIAnswerProcessor:
    """
    Processes an answer similar to Eval AI
        copied from
        https://github.com/facebookresearch/pythia/blob/c46b3b3391275b4181567db80943473a89ab98ab/pythia/tasks/processors.py#L897
    """

    CONTRACTIONS = {
        "aint": "ain't",
        "arent": "aren't",
        "cant": "can't",
        "couldve": "could've",
        "couldnt": "couldn't",
        "couldn'tve": "couldn't've",
        "couldnt've": "couldn't've",
        "didnt": "didn't",
        "doesnt": "doesn't",
        "dont": "don't",
        "hadnt": "hadn't",
        "hadnt've": "hadn't've",
        "hadn'tve": "hadn't've",
        "hasnt": "hasn't",
        "havent": "haven't",
        "hed": "he'd",
        "hed've": "he'd've",
        "he'dve": "he'd've",
        "hes": "he's",
        "howd": "how'd",
        "howll": "how'll",
        "hows": "how's",
        "Id've": "I'd've",
        "I'dve": "I'd've",
        "Im": "I'm",
        "Ive": "I've",
        "isnt": "isn't",
        "itd": "it'd",
        "itd've": "it'd've",
        "it'dve": "it'd've",
        "itll": "it'll",
        "let's": "let's",
        "maam": "ma'am",
        "mightnt": "mightn't",
        "mightnt've": "mightn't've",
        "mightn'tve": "mightn't've",
        "mightve": "might've",
        "mustnt": "mustn't",
        "mustve": "must've",
        "neednt": "needn't",
        "notve": "not've",
        "oclock": "o'clock",
        "oughtnt": "oughtn't",
        "ow's'at": "'ow's'at",
        "'ows'at": "'ow's'at",
        "'ow'sat": "'ow's'at",
        "shant": "shan't",
        "shed've": "she'd've",
        "she'dve": "she'd've",
        "she's": "she's",
        "shouldve": "should've",
        "shouldnt": "shouldn't",
        "shouldnt've": "shouldn't've",
        "shouldn'tve": "shouldn't've",
        "somebody'd": "somebodyd",
        "somebodyd've": "somebody'd've",
        "somebody'dve": "somebody'd've",
        "somebodyll": "somebody'll",
        "somebodys": "somebody's",
        "someoned": "someone'd",
        "someoned've": "someone'd've",
        "someone'dve": "someone'd've",
        "someonell": "someone'll",
        "someones": "someone's",
        "somethingd": "something'd",
        "somethingd've": "something'd've",
        "something'dve": "something'd've",
        "somethingll": "something'll",
        "thats": "that's",
        "thered": "there'd",
        "thered've": "there'd've",
        "there'dve": "there'd've",
        "therere": "there're",
        "theres": "there's",
        "theyd": "they'd",
        "theyd've": "they'd've",
        "they'dve": "they'd've",
        "theyll": "they'll",
        "theyre": "they're",
        "theyve": "they've",
        "twas": "'twas",
        "wasnt": "wasn't",
        "wed've": "we'd've",
        "we'dve": "we'd've",
        "weve": "we've",
        "werent": "weren't",
        "whatll": "what'll",
        "whatre": "what're",
        "whats": "what's",
        "whatve": "what've",
        "whens": "when's",
        "whered": "where'd",
        "wheres": "where's",
        "whereve": "where've",
        "whod": "who'd",
        "whod've": "who'd've",
        "who'dve": "who'd've",
        "wholl": "who'll",
        "whos": "who's",
        "whove": "who've",
        "whyll": "why'll",
        "whyre": "why're",
        "whys": "why's",
        "wont": "won't",
        "wouldve": "would've",
        "wouldnt": "wouldn't",
        "wouldnt've": "wouldn't've",
        "wouldn'tve": "wouldn't've",
        "yall": "y'all",
        "yall'll": "y'all'll",
        "y'allll": "y'all'll",
        "yall'd've": "y'all'd've",
        "y'alld've": "y'all'd've",
        "y'all'dve": "y'all'd've",
        "youd": "you'd",
        "youd've": "you'd've",
        "you'dve": "you'd've",
        "youll": "you'll",
        "youre": "you're",
        "youve": "you've",
    }

    NUMBER_MAP = {
        "none": "0",
        "zero": "0",
        "one": "1",
        "two": "2",
        "three": "3",
        "four": "4",
        "five": "5",
        "six": "6",
        "seven": "7",
        "eight": "8",
        "nine": "9",
        "ten": "10",
    }
    ARTICLES = ["a", "an", "the"]
    PERIOD_STRIP = re.compile("(?!<=\d)(\.)(?!\d)")
    COMMA_STRIP = re.compile("(?<=\d)(\,)+(?=\d)")
    PUNCTUATIONS = [
        ";",
        r"/",
        "[",
        "]",
        '"',
        "{",
        "}",
        "(",
        ")",
        "=",
        "+",
        "\\",
        "_",
        "-",
        ">",
        "<",
        "@",
        "`",
        ",",
        "?",
        "!",
    ]

    def __init__(self, *args, **kwargs):
        pass

    def word_tokenize(self, word):
        word = word.lower()
        word = word.replace(",", "").replace("?", "").replace("'s", " 's")
        return word.strip()

    def process_punctuation(self, in_text):
        out_text = in_text
        for p in self.PUNCTUATIONS:
            if (p + " " in in_text or " " + p in in_text) or (
                re.search(self.COMMA_STRIP, in_text) is not None
            ):
                out_text = out_text.replace(p, "")
            else:
                out_text = out_text.replace(p, " ")
        out_text = self.PERIOD_STRIP.sub("", out_text, re.UNICODE)
        return out_text

    def process_digit_article(self, in_text):
        out_text = []
        temp_text = in_text.lower().split()
        for word in temp_text:
            word = self.NUMBER_MAP.setdefault(word, word)
            if word not in self.ARTICLES:
                out_text.append(word)
            else:
                pass
        for word_id, word in enumerate(out_text):
            if word in self.CONTRACTIONS:
                out_text[word_id] = self.CONTRACTIONS[word]
        out_text = " ".join(out_text)
        return out_text

    def __call__(self, item):
        item = self.word_tokenize(item)
        item = item.replace("\n", " ").replace("\t", " ").strip()
        item = self.process_punctuation(item)
        item = self.process_digit_article(item)
        return item


class TextVQAAccuracyEvaluator:
    def __init__(self):
        self.answer_processor = EvalAIAnswerProcessor()

    def _compute_answer_scores(self, raw_answers):
        """
        compute the accuracy (soft score) of human answers
        """
        answers = [self.answer_processor(a) for a in raw_answers]
        print(f"Answers are : {answers}")
        #assert len(answers) == 10
        gt_answers = list(enumerate(answers))
        unique_answers = set(answers)
        unique_answer_scores = {}

        for unique_answer in unique_answers:
            accs = []
            for gt_answer in gt_answers:
                other_answers = [item for item in gt_answers if item != gt_answer]
                matching_answers = [
                    item for item in other_answers if item[1] == unique_answer
                ]
                acc = min(1, float(len(matching_answers)) / 3)
                accs.append(acc)
            unique_answer_scores[unique_answer] = sum(accs) / len(accs)

        return unique_answer_scores

    def eval_pred_list(self, pred_list):
        pred_scores = []
        for entry in pred_list:
            pred_answer = self.answer_processor(entry["pred_answer"])
            print(f"Pred answer is : {pred_answer}")
            unique_answer_scores = self._compute_answer_scores(entry["gt_answers"])
            score = unique_answer_scores.get(pred_answer, 0.0)
            pred_scores.append(score)

        accuracy = sum(pred_scores) / len(pred_scores)
        return accuracy, pred_scores

In [30]:
metric = TextVQAAccuracyEvaluator()

In [31]:
labels = tokenizer.batch_decode(sample['labels'], skip_special_tokens = True)

In [32]:
current_prediction = []
current_prediction.append({
                "gt_answers": labels,
                "pred_answer": labels[0],
            })

In [33]:
accuracy, pred_scores = metric.eval_pred_list(current_prediction)

Pred answer is : ec
Answers are : ['ec']


In [34]:
accuracy

0.0