In [1]:
from datasets import load_from_disk
import torch
from utils_modeling import get_tokens_with_boxes, normalize_box
# from transformers.utils import is_torch_fx_proxy
from modeling import PatchEmbeddings, SpatialModule, DocFormerV2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the pre-training dataset from a path relative to the notebook's working directory.
# The repr displayed a few cells below shows a DatasetDict with a single `train` split.
raw_datasets = load_from_disk("../data/idl-pretrain-dataset")

In [3]:
# # create rectangle image
# img = item['img']
# bbox = item['bbox'][:-1] ## Removing the website
# words = item['words'][:-1] ## Removing the website
# draw_on_img = ImageDraw.Draw(img)  

# for it in bbox:
#   draw_on_img.rectangle(it, outline ="violet")
# img

In [4]:
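# NOTE: reference copy only. The notebook imports `normalize_box` and `get_tokens_with_boxes` from `utils_modeling` in the first cell; whether that module matches this copy line for line has not been checked here.
# ## Pre-processing bounding boxes, ref: https://github.com/uakarsh/latr/blob/1e73c1a99f9a0db4d85177259226148a65556069/src/new_latr/dataset.py#L34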

# def normalize_box(box, width, height, size=1000):
#     """
#     Takes a bounding box and normalizes it to a thousand pixels. If you notice it is
#     just like calculating percentage except takes 1000 instead of 100.

#     Arguments:
#         box: A list of bounding box coordinates
#         width: The width of the image
#         height: The height of the image
#         size: The size to normalize to
#     Returns:
#         A list of normalized bounding box coordinates
#     """
#     return [
#         int(size * (box[0] / width)),
#         int(size * (box[1] / height)),
#         int(size * (box[2] / width)),
#         int(size * (box[3] / height)),
#     ]

# def get_tokens_with_boxes(bounding_boxes, list_of_words, tokenizer, pad_token_box=[0, 0, 0, 0], max_seq_len=-1, eos_token_box=[0, 0, 1000, 1000]):

#     '''
#     A function to get the tokens with the bounding boxes
#     Arguments:
#         bounding_boxes: A list of bounding boxes
#         list_of_words: A list of words
#         tokenizer: The tokenizer to use
#         pad_token_box: The padding token box
#         max_seq_len: The maximum sequence length, not padded if max_seq_len is -1
#         eos_token_box: The end of sequence token box
#     Returns:
#         A list of input_ids, bbox_according_to_tokenizer, attention_mask
#     '''

#     # 2. Performing the semantic pre-processing
#     encoding = tokenizer(list_of_words, is_split_into_words=True,
#                          add_special_tokens=False)

#     input_ids = encoding['input_ids']
#     attention_mask = encoding['attention_mask']

#     # Note that, there is no need for bboxes, since the model does not use bbox as feature, so no pre-processing of that
#     bbox_according_to_tokenizer = [bounding_boxes[i]
#                                    for i in encoding.word_ids()]

#     # Truncation of token_boxes + token_labels
#     special_tokens_count = 1
#     if max_seq_len != -1 and len(input_ids) > max_seq_len - special_tokens_count:
#         bbox_according_to_tokenizer = bbox_according_to_tokenizer[: (
#             max_seq_len - special_tokens_count)]
#         input_ids = input_ids[: (max_seq_len - special_tokens_count)]
#         attention_mask = attention_mask[: (max_seq_len - special_tokens_count)]

#     ## Adding End of sentence token
#     input_ids = input_ids + [tokenizer.eos_token_id]
#     bbox_according_to_tokenizer = bbox_according_to_tokenizer + [eos_token_box]
#     attention_mask = attention_mask + [1]

#     # Padding
#     if max_seq_len != -1 and len(input_ids) < max_seq_len:
#         pad_length = max_seq_len - len(input_ids)

#         input_ids = input_ids + [tokenizer.pad_token_id] * (pad_length)
#         bbox_according_to_tokenizer = bbox_according_to_tokenizer + \
#             [pad_token_box] * (pad_length)
#         attention_mask = attention_mask + [0] * (pad_length)

#     return dict(input_ids = input_ids, bboxes = bbox_according_to_tokenizer, attention_mask = attention_mask)

In [5]:
from transformers import AutoTokenizer, PretrainedConfig, AutoConfig
from torchvision.transforms import Compose, Normalize, Resize, ToTensor

# Notebook-specific settings that are layered on top of the pretrained T5 configuration.
# Kept under its own name so that `config` always refers to the final config object
# (never to this plain dict), which keeps the cell safe to re-run and easier to read.
config_overrides = {
    'model_name' : 'google-t5/t5-small', ## can be 'google-t5/t5-small', 'google-t5/t5-base', 'google-t5/t5-large'
    'image_model' : 'microsoft/resnet-50',
    'image_mean' : [0.485, 0.456, 0.406], ## resnet-50 configuration
    'image_std' : [0.229, 0.224, 0.225],
    'image_size' : (224, 224),
    'patch_size' : (2, 2),
    'num_channels' : 3,
    'max_image_tokens' : 128, ## Found from ablation
    'max_2d_position_embeddings' : 1024,
}

# Start from the checkpoint's own config and attach the extra settings as attributes.
t5_config = AutoConfig.from_pretrained(config_overrides['model_name'])
t5_config.update(config_overrides)
config = t5_config
tokenizer = AutoTokenizer.from_pretrained(config.model_name)

# `Resize` needs an explicit (height, width) pair to force a fixed output size; a bare int
# would only constrain the shorter edge. Accept lists as well as tuples so that a list-valued
# `image_size` is not wrapped into a pair of lists.
if isinstance(config.image_size, (tuple, list)):
    resize_to = tuple(config.image_size)
else:
    resize_to = (config.image_size, config.image_size)

image_transform = Compose([
    Resize(resize_to),
    ToTensor(),
    Normalize(mean=config.image_mean, std=config.image_std),
])

In [6]:
# The stored boxes are treated as belonging to a 1000 x 1000 page.
orig_size = (1000, 1000)


def _normalize_example_boxes(example):
    """Return the example's boxes passed through `normalize_box` with the page size above."""
    page_width, page_height = orig_size[0], orig_size[1]
    return {'bbox': [normalize_box(box, page_width, page_height) for box in example['bbox']]}


def _tokenize_example(example):
    """Tokenize the example's words and align its boxes via `get_tokens_with_boxes`."""
    return get_tokens_with_boxes(
        example['bbox'],
        example['words'],
        tokenizer=tokenizer,
        max_seq_len=tokenizer.model_max_length,
    )


# Step 1: normalize the bounding boxes between 0 and 1000 (one example at a time).
raw_datasets['train'] = raw_datasets['train'].map(_normalize_example_boxes, batched=False)

# Step 2: replace the raw `bbox` / `words` columns with the tokenizer-aligned outputs.
raw_datasets['train'] = raw_datasets['train'].map(
    _tokenize_example,
    batched=False,
    remove_columns=['bbox', 'words'],
)

In [7]:
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['img', 'input_ids', 'bboxes', 'attention_mask'],
        num_rows: 2
    })
})

In [8]:
def preprocess(example_batch):
    """Turn a batch of dataset columns into per-example tensors.

    Reads the `img`, `input_ids`, `bboxes` and `attention_mask` columns of
    `example_batch` and returns a dict with the keys `pixel_values`,
    `input_ids`, `bbox` and `attention_mask`. Each value is a list holding one
    entry per example. Note the rename: the `bboxes` column comes back as `bbox`.
    """
    def to_long_tensors(rows):
        # One integer (int64) tensor per example.
        return [torch.tensor(row).long() for row in rows]

    # Images go through the module-level `image_transform` pipeline.
    pixel_values = [image_transform(picture) for picture in example_batch["img"]]

    return {
        "pixel_values": pixel_values,
        "input_ids": to_long_tensors(example_batch["input_ids"]),
        "bbox": to_long_tensors(example_batch["bboxes"]),
        "attention_mask": to_long_tensors(example_batch["attention_mask"]),
    }

In [9]:
# Register `preprocess` as the dataset's access-time transform. The key listing shown a few
# cells below (pixel_values, input_ids, bbox, attention_mask) reflects its output.
raw_datasets.set_transform(preprocess)

## Modeling

1. Image Embedding

In [10]:
# import torch.nn as nn
# import collections.abc

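# NOTE: reference copy only. The notebook imports `PatchEmbeddings` from `modeling` in the first cell; whether that module matches this copy line for line has not been checked here.
# ## Ref: https://github.com/huggingface/transformers/blob/ff841900e45763114d2417fb24ce29d950c6c956/src/transformers/models/vit/modeling_vit.py#L146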
# class PatchEmbeddings(nn.Module):
#     """
#     This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
#     `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
#     Transformer.
#     """

#     def __init__(self, config):
#         super().__init__()
#         image_size, patch_size = config.image_size, config.patch_size
#         num_channels, hidden_size = config.num_channels, config.d_model

#         image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
#         patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
#         num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
#         self.image_size = image_size
#         self.patch_size = patch_size
#         self.num_channels = num_channels
#         self.num_patches = num_patches
#         self.max_image_tokens = config.max_image_tokens ## If we limit the max_image_tokens to a number, would it capture the global context?
#         ## Should we keep the convolution kernel's size to be (16, 16) rather than just (2, 2), so sequence length can be reduced and we are
#         ## able to capture global context?

#         self.conv_projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
#         self.linear_projection = nn.Linear(hidden_size, hidden_size)

#         self.positional_embedding = nn.Embedding(self.max_image_tokens, hidden_size)


#     def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
#         _, num_channels, height, width = pixel_values.shape
#         if num_channels != self.num_channels:
#             raise ValueError(
#                 "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
#                 f" Expected {self.num_channels} but got {num_channels}."
#             )
#         if not interpolate_pos_encoding:
#             if height != self.image_size[0] or width != self.image_size[1]:
#                 raise ValueError(
#                     f"Input image size ({height}*{width}) doesn't match model"
#                     f" ({self.image_size[0]}*{self.image_size[1]})."
#                 )
#         embeddings = self.conv_projection(pixel_values).flatten(2).transpose(1, 2)
#         embeddings = self.linear_projection(embeddings)[:, :self.max_image_tokens, :]

#         positions = torch.arange(0, self.max_image_tokens).unsqueeze(0).to(embeddings.device)
#         position_embedding = self.positional_embedding(positions)

#         return embeddings + position_embedding

In [11]:
# Smoke-test the image patch embedding on the first training example.
img_feature_extractor = PatchEmbeddings(config)
first_pixel_values = raw_datasets['train'][0]['pixel_values']
# unsqueeze(0) adds the leading batch dimension before calling the module
sample_img_emb = img_feature_extractor(first_pixel_values.unsqueeze(0))

2. Text and Spatial Embedding

In [12]:
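# NOTE: reference copy only. The notebook imports `SpatialModule` from `modeling` in the first cell; whether that module matches this copy line for line has not been checked here.
# ## Ref: https://github.com/uakarsh/latr/blob/1e73c1a99f9a0db4d85177259226148a65556069/src/new_latr/modeling.py#L11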

# class SpatialModule(nn.Module):
#     def __init__(self, config):
#         super().__init__()
#         self.top_left_x = nn.Embedding(
#             config.max_2d_position_embeddings, config.d_model // 2)
#         self.bottom_right_x = nn.Embedding(
#             config.max_2d_position_embeddings, config.d_model // 2)
#         self.top_left_y = nn.Embedding(
#             config.max_2d_position_embeddings, config.d_model // 2)
#         self.bottom_right_y = nn.Embedding(
#             config.max_2d_position_embeddings, config.d_model // 2)
#         self.width_emb = nn.Embedding(config.max_2d_position_embeddings, config.d_model)
#         self.height_emb = nn.Embedding(
#             config.max_2d_position_embeddings, config.d_model)

#     def forward(self, coordinates):

#         top_left_x_feat = self.top_left_x(coordinates[:, :, 0])
#         top_left_y_feat = self.top_left_y(coordinates[:, :, 1])
#         bottom_right_x_feat = self.bottom_right_x(coordinates[:, :, 2])
#         bottom_right_y_feat = self.bottom_right_y(coordinates[:, :, 3])
#         width_feat = self.width_emb(coordinates[:, :, 2] - coordinates[:, :, 0])
#         height_feat = self.height_emb(coordinates[:, :, 3] - coordinates[:, :, 1])

#         layout_feature = torch.cat([top_left_x_feat, bottom_right_x_feat], axis = -1) + torch.cat([top_left_y_feat, bottom_right_y_feat], axis = -1) + \
#              width_feat + height_feat
#         return layout_feature

In [13]:
# Smoke-test the spatial (bounding-box) embedding on the first training example.
spatial_feature_extractor = SpatialModule(config)
first_bbox = raw_datasets['train'][0]['bbox']
# unsqueeze(0) adds the leading batch dimension before calling the module
spatial_feat = spatial_feature_extractor(first_bbox.unsqueeze(0))

In [14]:
# Check which fields a single example exposes once the transform is applied.
raw_datasets['train'][0].keys()

dict_keys(['pixel_values', 'input_ids', 'bbox', 'attention_mask'])

In [15]:
# Sanity-check the per-example box tensor shape (one row of 4 coordinates per token position).
raw_datasets['train'][0]['bbox'].shape

torch.Size([512, 4])

In [16]:
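# NOTE: reference copy only. The notebook imports `DocFormerV2` from `modeling` in the first cell; whether that module matches this copy line for line has not been checked here.
# from transformers import T5ForConditionalGeneration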
# from torch.nn import CrossEntropyLoss
# from transformers.modeling_outputs import (
#     Seq2SeqLMOutput,
# )

# class DocFormerV2(T5ForConditionalGeneration):
#     def __init__(self, config):
#         super().__init__(config=config)
#         self.spatial_feat_extractor = SpatialModule(config)
#         self.img_feature_extractor = PatchEmbeddings(config)
#         self.modality_embedding = nn.Embedding(2, config.d_model)

#     def forward(
#             self,
#             input_ids=None,
#             bbox=None,
#             attention_mask=None,
#             decoder_input_ids=None,
#             decoder_attention_mask=None,
#             encoder_outputs=None,
#             past_key_values=None,
#             pixel_values=None,
#             labels=None,
#             head_mask=None,
#             inputs_embeds=None,
#             decoder_inputs_embeds=None,
#             decoder_head_mask=None,
#             cross_attn_head_mask=None,
#             use_cache=True,
#             output_attentions=None,
#             output_hidden_states=None,
#             return_dict=None,
#             **kwargs,) :

#         use_cache = use_cache if use_cache is not None else self.config.use_cache
#         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

#         if decoder_input_ids is None and labels is not None:
#             decoder_input_ids = self._shift_right(labels)

#         # Encode if needed (training, first prediction pass)
#         if encoder_outputs is None:
#             inputs_embeds, attention_mask = self.calculate_embedding(
#                 pixel_values, bbox, input_ids, attention_mask)
#             encoder_outputs = self.encoder(
#                 attention_mask=attention_mask,
#                 inputs_embeds=inputs_embeds,
#                 head_mask=head_mask,
#                 output_attentions=output_attentions,
#                 output_hidden_states=output_hidden_states,
#                 return_dict=return_dict,
#             )
#         hidden_states = encoder_outputs[0]

#         if decoder_input_ids == None:
#             decoder_input_ids = self._shift_right(input_ids)

#         # Decode
#         decoder_outputs = self.decoder(
#             input_ids=decoder_input_ids,
#             attention_mask=decoder_attention_mask,
#             inputs_embeds=decoder_inputs_embeds,
#             past_key_values=past_key_values,
#             encoder_hidden_states=hidden_states,
#             encoder_attention_mask=attention_mask,
#             head_mask=decoder_head_mask,
#             cross_attn_head_mask=cross_attn_head_mask,
#             use_cache=use_cache,
#             output_attentions=output_attentions,
#             output_hidden_states=output_hidden_states,
#             return_dict=return_dict,
#         )

#         sequence_output = decoder_outputs[0]

#         if self.config.tie_word_embeddings:
#             # Rescale output before projecting on vocab
#             # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
#             sequence_output = sequence_output * (self.config.d_model**-0.5)

#         lm_logits = self.lm_head(sequence_output)

#         loss = None
#         if labels is not None:
#             loss_fct = CrossEntropyLoss(ignore_index=-100)
#             loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))

#         if not return_dict:
#             output = (lm_logits,) + \
#                 decoder_outputs[2:] + (encoder_outputs[0],) + encoder_outputs[2:]
#             return ((loss,) + output) if loss is not None else output

#         return Seq2SeqLMOutput(
#             loss=loss,
#             logits=lm_logits,
#             past_key_values=decoder_outputs.past_key_values,
#             decoder_hidden_states=decoder_outputs.hidden_states,
#             decoder_attentions=decoder_outputs.attentions,
#             cross_attentions=decoder_outputs.cross_attentions,
#             encoder_last_hidden_state=encoder_outputs.last_hidden_state,
#             encoder_hidden_states=encoder_outputs.hidden_states,
#             encoder_attentions=encoder_outputs.attentions,
#         )
    
#     def _shift_right(self, input_ids):
#         decoder_start_token_id = self.config.decoder_start_token_id
#         pad_token_id = self.config.pad_token_id

#         if decoder_start_token_id is None:
#             raise ValueError(
#                 "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. "
#                 "See T5 docs for more information."
#             )

#         # shift inputs to the right
#         if is_torch_fx_proxy(input_ids):
#             # Item assignment is not supported natively for proxies.
#             shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
#             shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
#         else:
#             shifted_input_ids = input_ids.new_zeros(input_ids.shape)
#             shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
#             shifted_input_ids[..., 0] = decoder_start_token_id

#         if pad_token_id is None:
#             raise ValueError("self.model.config.pad_token_id has to be defined.")
#         # replace possible -100 values in labels by `pad_token_id`
#         shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

#         return shifted_input_ids

#     def prepare_inputs_for_generation(
#         self,
#         input_ids,
#         past_key_values=None,
#         attention_mask=None,
#         head_mask=None,
#         decoder_head_mask=None,
#         cross_attn_head_mask=None,
#         use_cache=None,
#         encoder_outputs=None,
#         **kwargs,
#     ):
#         # cut decoder_input_ids if past is used
#         if past_key_values is not None:
#             input_ids = input_ids[:, -1:]

#         return {
#             "decoder_input_ids": input_ids,
#             "past_key_values": past_key_values,
#             "encoder_outputs": encoder_outputs,
#             "attention_mask": attention_mask,
#             "head_mask": head_mask,
#             "decoder_head_mask": decoder_head_mask,
#             "cross_attn_head_mask": cross_attn_head_mask,
#             "use_cache": use_cache,
#             "bbox": kwargs.get("bbox", None),
#             "pixel_values": kwargs.get("pixel_values", None),
#         }

#     def calculate_embedding(self, img, bbox, input_ids, attention_mask):
#         img_feat = self.img_feature_extractor(img)
#         spatial_feat = self.spatial_feat_extractor(bbox)
#         language_feat = self.shared(input_ids)

#         layout_feat = spatial_feat + language_feat
#         img_modality_token = self.modality_embedding(torch.zeros(1, img_feat.shape[1]).long().to(self.device))
#         lang_modality_token = self.modality_embedding(torch.ones(1, language_feat.shape[1]).long().to(self.device))

#         img_feat += img_modality_token
#         layout_feat += lang_modality_token

#         multi_modal_feat = torch.cat([img_feat, layout_feat], axis=1)
#         input_attention_mask = torch.cat(
#             [torch.ones(img_feat.shape[:2]).to(img_feat.device), attention_mask], axis=1)
        
#         return multi_modal_feat, input_attention_mask

In [17]:
# Take the first transformed example and give every field a leading batch axis of size 1.
first_example = raw_datasets['train'][0]
sample = {field: value.unsqueeze(0) for field, value in first_example.items()}

In [18]:
# Build the model directly from the config object.
# NOTE(review): no `from_pretrained` call appears here, so the weights are presumably freshly
# initialised rather than loaded from a checkpoint; confirm against `modeling.DocFormerV2`.
docformer_v2 = DocFormerV2(config)

In [19]:
# Confirm the batched image tensor shape before running the model.
sample['pixel_values'].shape

torch.Size([1, 3, 224, 224])

In [20]:
# Forward pass on the single batched example, passing every field of `sample` as a keyword
# argument. No `labels` are supplied, which matches the `loss=None` in the output shown below.
# NOTE(review): the displayed logits carry a `grad_fn`, i.e. autograd tracking is active here;
# wrap the call in `torch.no_grad()` if this is only meant as an inference smoke test.
output = docformer_v2(**sample)

In [23]:
# NOTE(review): a bare `output` renders every tensor in the returned Seq2SeqLMOutput (logits and
# past_key_values, as the repr below shows); inspecting e.g. `output.logits.shape` instead
# would keep the saved notebook output small.
output

Seq2SeqLMOutput(loss=None, logits=tensor([[[ 6.2223, -0.0186, -0.4801,  ..., -0.6275,  0.9068, -0.2347],
         [ 0.9536, -1.4644,  2.2863,  ...,  0.5846,  0.4761, -0.5474],
         [ 0.0564, -0.3024,  1.2863,  ..., -0.5176,  0.5124, -1.1346],
         ...,
         [ 0.4013,  0.9506,  1.5823,  ...,  1.9127, -0.8710, -0.0522],
         [ 0.6935, -0.1235,  2.1049,  ..., -0.7052,  1.2970, -0.6487],
         [ 0.5439, -1.3241,  1.1074,  ..., -0.2089,  0.6928, -1.1619]]],
       grad_fn=<UnsafeViewBackward0>), past_key_values=((tensor([[[[-0.0593,  0.2877, -0.2666,  ...,  1.1694,  0.3735,  0.1573],
          [ 1.3052, -1.8031, -0.9118,  ..., -1.8054, -0.3311, -0.1553],
          [ 1.9435,  0.6398,  0.7937,  ..., -0.2507,  0.7002,  0.6433],
          ...,
          [ 0.6455,  0.8375, -1.2233,  ..., -0.0106, -0.5837,  0.5068],
          [ 1.0143, -0.9986, -1.8679,  ...,  0.4936, -0.5922,  0.8664],
          [ 0.9427, -0.4008,  2.6253,  ..., -0.6578, -0.3355, -1.6953]],

         [[ 0.1454