In [1]:
from utils import convert_ans_to_token, convert_ques_to_token, rotate, convert_token_to_ques, convert_token_to_answer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from dataset import load_json_file, get_specific_file, resize_align_bbox, get_tokens_with_boxes, create_features

In [3]:
import os
import json
import pandas as pd
import torch
from torchvision import transforms
import sys

In [4]:
from tqdm.auto import tqdm

In [5]:
tqdm.pandas()

In [6]:
PAD_TOKEN_BOX = [0, 0, 0, 0]
max_seq_len = 512
batch_size = 2
target_size = (500,384) ## Note that, ViT would make it 224x224 so :(

In [7]:
base_path = "../../textvqa_eval/"

train_ocr_json_path = os.path.join(base_path, 'TextVQA_Rosetta_OCR_v0.2_train.json')
train_json_path = os.path.join(base_path, 'TextVQA_0.5.1_train.json')

val_ocr_json_path = os.path.join(base_path, 'TextVQA_Rosetta_OCR_v0.2_val.json')
val_json_path = os.path.join(base_path, 'TextVQA_0.5.1_val.json')

In [8]:
train_ocr_json = json.load(open(train_ocr_json_path))['data']
train_json = json.load(open(train_json_path))['data']

val_ocr_json = json.load(open(val_ocr_json_path))['data']
val_json = json.load(open(val_json_path))['data']

In [9]:
train_json_df = pd.DataFrame(train_json)
train_ocr_json_df = pd.DataFrame(train_ocr_json)

val_json_df = pd.DataFrame(val_json)
val_ocr_json_df = pd.DataFrame(val_ocr_json)

In [10]:
train_json_df['answers'] = train_json_df['answers'].apply(lambda x: " ".join(list(map(str, x))))
val_json_df['answers']   = val_json_df['answers'].apply(lambda x: " ".join(list(map(str, x))))

In [11]:
## Dropping of the images which doesn't exist, might take some time

base_img_path = os.path.join(base_path, 'train_images')

train_json_df['path_exists'] = train_json_df['image_id'].progress_apply(lambda x: os.path.exists(os.path.join(base_img_path, x)+'.jpg'))
train_json_df = train_json_df[train_json_df['path_exists']==True]

val_json_df['path_exists'] = val_json_df['image_id'].progress_apply(lambda x: os.path.exists(os.path.join(base_img_path, x)+'.jpg'))
val_json_df = val_json_df[val_json_df['path_exists']==True]

100%|██████████| 34602/34602 [00:06<00:00, 5289.37it/s]
100%|██████████| 5000/5000 [00:00<00:00, 5466.71it/s]


In [12]:
## Dropping the unused columns

train_json_df.drop(columns = ['flickr_original_url', 'flickr_300k_url','image_classes', 'question_tokens', 'path_exists'
                              ], axis = 1, inplace = True)
val_json_df.drop(columns = ['flickr_original_url', 'flickr_300k_url','image_classes', 'question_tokens', 'path_exists'
                              ], axis = 1, inplace = True)

In [13]:
## Deleting the json

del train_json
del train_ocr_json
del val_json
del val_ocr_json

In [14]:
## Grouping for the purpose of feature extraction
grouped_df = train_json_df.groupby('image_id')

## Getting all the unique keys of the group by object
keys = list(grouped_df.groups.keys())

In [15]:
## Create dataset class for TextVQA

class TextVqaDataset(torch.utils.data.Dataset):
    def __init__(self, base_img_path, json_df, ocr_json_df, tokenizer, 
    transform = None, max_seq_length = 100, target_size = (500,384), fine_tune = True):
        self.base_img_path = base_img_path
        self.json_df = json_df
        self.ocr_json_df = ocr_json_df
        self.tokenizer = tokenizer
        self.target_size = target_size
        self.transform = transform
        self.max_seq_length = max_seq_length
        self.fine_tune = fine_tune

    def __len__(self):
        return len(self.json_df)

    def __getitem__(self, idx):
        curr_img = self.json_df.iloc[idx]['image_id']
        ocr_token = self.ocr_json_df[self.ocr_json_df['image_id']==curr_img]['ocr_info'].values.tolist()[0]

        boxes = []
        words = []

        current_group = self.json_df.iloc[idx]
        width, height = current_group['image_width'], current_group['image_height']

        for entry in ocr_token:
            xmin, ymin, w, h, angle = entry['bounding_box']['top_left_x'], entry['bounding_box']['top_left_y'],  entry['bounding_box']['width'],  entry['bounding_box']['height'], entry['bounding_box']['rotation']
            xmin, ymin,w, h = resize_align_bbox([xmin, ymin, w, h], 1, 1, width, height)
            
            x_centre = xmin + (w/2)
            y_centre = ymin + (h/2)

            ## print("The angle is:", angle)
            xmin, ymin = rotate([x_centre, y_centre], [xmin, ymin], angle)

            xmax = xmin + w
            ymax = ymin + h

            ## Bounding boxes are normalized
            curr_bbox = [xmin, ymin, xmax, ymax]
            boxes.append(curr_bbox)
            words.append(entry['word'])

        img_path = os.path.join(self.base_img_path, curr_img)+'.jpg'
        assert os.path.exists(img_path)==True, f'Make sure that the image exists at {img_path}!!'

        if self.fine_tune:
            ## For fine-tune stage, they use [0, 0, 1000, 1000] for all the bounding box
            img = Image.open(img_path).convert("RGB")
            img = img.resize(self.target_size)
            boxes = torch.zeros(self.max_seq_length, 4)
            boxes[:, 2] = 1000
            boxes[:, 3] = 1000
            
            words = " ".join(words)
            tokenized_words = self.tokenizer.encode(words, max_length = self.max_seq_length, 
                truncation = True, padding = 'max_length', return_tensors = 'pt')[0]
        else:
            raise NotImplementedError("Flow for `self.fine_tune != False` is not defined!")

        ## Converting the boxes as per the format required for model input
        boxes = torch.as_tensor(boxes, dtype=torch.int32)
        width = (boxes[:, 2] - boxes[:, 0]).view(-1, 1)
        height = (boxes[:, 3] - boxes[:, 1]).view(-1, 1)
        boxes = torch.cat([boxes, width, height], axis = -1)

        ## Clamping the value,as some of the box values are out of bound
        boxes[:, 0] = torch.clamp(boxes[:, 0], min = 0, max = 1000)
        boxes[:, 2] = torch.clamp(boxes[:, 2], min = 0, max = 1000)
        boxes[:, 4] = torch.clamp(boxes[:, 4], min = 0, max = 1000)
        
        boxes[:, 1] = torch.clamp(boxes[:, 1], min = 0, max = 1000)
        boxes[:, 3] = torch.clamp(boxes[:, 3], min = 0, max = 1000)
        boxes[:, 5] = torch.clamp(boxes[:, 5], min = 0, max = 1000)
        
        ## Tensor tokenized words
        tokenized_words = torch.as_tensor(tokenized_words, dtype=torch.int32)

        if self.transform is not None:
            img = self.transform(img)
        else:
            img = transforms.ToTensor()(img)


        ## Getting the Question
        question = current_group['question']   
        question = convert_ques_to_token(question = question, tokenizer = self.tokenizer)

        ## Getting the Answer
        answer = current_group['answers']
        answer = convert_ques_to_token(question = answer, tokenizer = self.tokenizer).long()

        return {'img':img, 'boxes': boxes, 'tokenized_words': tokenized_words, 
                'question': question, 'answer': answer, 'id': torch.as_tensor(idx)}

In [20]:
from transformers import BertTokenizer

In [21]:
def get_pretrained_tokenizer(from_pretrained):
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            BertTokenizer.from_pretrained(
                from_pretrained, do_lower_case="uncased" in from_pretrained
            )
        torch.distributed.barrier()
    return BertTokenizer.from_pretrained(
        from_pretrained, do_lower_case="uncased" in from_pretrained
    )

In [27]:
tokenizer = get_pretrained_tokenizer("bert-base-uncased")

train_ds = TextVqaDataset(base_img_path = base_img_path,
                         json_df = train_json_df,
                         ocr_json_df = train_ocr_json_df,
                         tokenizer = tokenizer,
                         transform = None, 
                         max_seq_length = max_seq_len, 
                         target_size = target_size
                         )


val_ds = TextVqaDataset(base_img_path = base_img_path,
                        json_df = val_json_df,
                        ocr_json_df = val_ocr_json_df,
                        tokenizer = tokenizer,
                        transform = None, 
                        max_seq_length = max_seq_len, 
                        target_size = target_size
                        )

In [28]:
import pytorch_lightning as pl

In [32]:
# sys.path.append("../vilt/transforms/")

In [37]:
# import pixelbert.pixelbert_transform
from ..vilt.transforms.pixelbert import pixelbert_transform

ImportError: attempted relative import with no known parent package

In [29]:
pb_transform = 

def collate_fn(data_bunch):
    '''
    A function for the dataloader to return a batch dict of given keys

    data_bunch: List of dictionary
    '''
    dict_data_bunch = {}

    for i in data_bunch:
        for (key, value) in i.items():
            if key not in dict_data_bunch:
                dict_data_bunch[key] = []
            dict_data_bunch[key].append(value)

    for key in list(dict_data_bunch.keys()):
        dict_data_bunch[key] = torch.stack(dict_data_bunch[key], axis = 0)

    if 'img' in dict_data_bunch:
        ## Pre-processing for ViT
        dict_data_bunch['img'] = vit_feat_extract(list(dict_data_bunch['img']),return_tensors = 'pt')['pixel_values']

    return dict_data_bunch

In [31]:
class DataModule(pl.LightningDataModule):
    def __init__(self, train_dataset, val_dataset,  batch_size = 32):
        super(DataModule, self).__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.train_dataset, batch_size = self.batch_size, 
                    collate_fn = collate_fn, shuffle = True, num_workers = 2, pin_memory = True, persistent_workers = True)
  
    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.val_dataset, batch_size = self.batch_size, 
                    collate_fn = collate_fn, shuffle = False, num_workers = 2, pin_memory = True, persistent_workers = True)