<a href="https://colab.research.google.com/github/uakarsh/TiLT-Implementation/blob/main/how_did_i_prepare_the_stuffs/tilt_part_3_1_aligning_all_the_parts_to_make_tilt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!git clone https://github.com/uakarsh/TiLT-Implementation.git

fatal: destination path 'TiLT-Implementation' already exists and is not an empty directory.


In [None]:
!pip install -r /content/TiLT-Implementation/requirements.txt

In [None]:
import sys

## Make the repository's `src/` modules (dataset, visual_backbone, t5) importable.
## Guarded so that re-running the cell does not append duplicate entries to sys.path.
SRC_DIR = "/content/TiLT-Implementation/src/"
if SRC_DIR not in sys.path:
  sys.path.append(SRC_DIR)

In [None]:
from transformers import AutoTokenizer, AutoConfig
from datasets import load_dataset
import torch
import torch.nn as nn

from dataset import FUNSDDs
from torchvision import transforms
from tqdm.auto import tqdm

## Custom imports
from visual_backbone import Unet_encoder, RoIPool
from t5 import T5ForConditionalGeneration, T5Stack
from transformers import AutoModel

## 1.1. Preparing the dataset

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

## Dataset and base checkpoint
hf_ds = load_dataset("nielsr/funsd-layoutlmv3")
model_name = "t5-base"

## Visual embedding extractor's parameters
in_channels = 3
num_pool_layers = 3
channels = 16
sampling_ratio = 2
spatial_scale = 48 / 384   # NOTE(review): presumably feature-map size / input image size — confirm
output_size = (3, 3)
load_weights = True

## Tokenizer's parameter
model_max_length = 512

## Start from the pretrained T5 config and attach the TiLT-specific fields on top of it
t5_config = AutoConfig.from_pretrained(model_name)
tilt_specific_fields = {
    "in_channels": in_channels,
    "num_pool_layers": num_pool_layers,
    "channels": channels,
    "model_max_length": model_max_length,
    "output_size": output_size,
    "spatial_scale": spatial_scale,
    "sampling_ratio": sampling_ratio,
    "use_cache": False,
    "load_weights": load_weights,
}
t5_config.update(tilt_specific_fields)

## Tokenizer (fast implementation, truncating/padding to model_max_length)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast = True, model_max_length = model_max_length)



  0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
def get_id2label_and_label2id():
    """Return the FUNSD BIO tag vocabulary as an ``(id2label, label2id)`` pair of dicts.

    ``id2label`` is built by inverting ``label2id`` so the two mappings can never drift apart.
    """
    label2id = {'O': 0, 'B-HEADER': 1, 'I-HEADER': 2, 'B-QUESTION': 3, 'I-QUESTION': 4, 'B-ANSWER': 5, 'I-ANSWER': 6}
    id2label = {idx: label for label, idx in label2id.items()}
    return id2label, label2id

def convert_id_to_label(list_of_label, id2label = None):
    """Map a sequence of integer NER tag ids to their string labels.

    Args:
      list_of_label: iterable of integer tag ids.
      id2label: optional ``id -> label`` mapping. Defaults to the FUNSD mapping from
        ``get_id2label_and_label2id`` rather than relying on a notebook-global variable
        that may not be defined yet (or may have been reassigned).

    Returns:
      list of label strings, in the same order as the input.

    Raises:
      KeyError: if an id is not present in the mapping.
    """
    if id2label is None:
        id2label, _ = get_id2label_and_label2id()
    return [id2label[x] for x in list_of_label]

In [None]:
id2label, label2id = get_id2label_and_label2id()

def _to_symmetric_range(image_tensor):
  """Linearly rescale a tensor via ``2 * x - 1`` (maps ToTensor's [0, 1] range for uint8 images to [-1, 1])."""
  return 2 * image_tensor - 1

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(_to_symmetric_range),
])

In [None]:
## Convert each example's integer NER tags into string labels (e.g. 3 -> 'B-QUESTION')
train_new_tags = [convert_id_to_label(tags) for tags in hf_ds['train']['ner_tags']]
test_new_tags = [convert_id_to_label(tags) for tags in hf_ds['test']['ner_tags']]

In [None]:
## Swap the integer `ner_tags` column for its string version in both splits
for split_name, string_tags in (('train', train_new_tags), ('test', test_new_tags)):
  hf_ds[split_name] = hf_ds[split_name].remove_columns("ner_tags").add_column("ner_tags", string_tags)

In [None]:
## Wrap both splits in the project's FUNSD dataset, sharing the tokenizer and image transform
train_ds, val_ds = (FUNSDDs(hf_ds[split], tokenizer = tokenizer, transform = transform) for split in ('train', 'test'))

### 1.2 Writing a custom `collate_fn` for the DataLoader

In [None]:
class CollateFn(object):
  """Collate a list of ``FUNSDDs`` samples into a single batch dict.

  The tensor fields of every sample are stacked along a new leading batch
  dimension, and the per-sample target words in ``labels`` are tokenized into
  a fixed-length (``padding='max_length'``) tensor of token ids.

  Args:
    tokenizer: Hugging Face tokenizer used to encode the target ``labels``.
    label_pad_token_id: id written over the padding positions of the encoded
      labels. ``-100`` (default) is the ignore index of Hugging Face seq2seq
      losses, so padding no longer contributes to the loss.
      NOTE(review): assumes the project's ``t5.T5ForConditionalGeneration`` keeps
      HF's ``ignore_index=-100`` and maps -100 back to pad when shifting labels — confirm.
      Pass ``None`` to keep the tokenizer's own pad id (previous behaviour).
  """

  def __init__(self, tokenizer, label_pad_token_id = -100):
    self.tokenizer = tokenizer
    self.label_pad_token_id = label_pad_token_id

  def __call__(self, list_of_ds):
    """Return a dict with ``input_ids``, ``attention_mask``, ``bboxes``, ``pixel_values`` and ``labels``."""
    simple_keys = ["input_ids", "attention_mask", "bboxes", "pixel_values"]
    ## torch.stack requires each sample's tensor for a given key to have the same shape
    actual_batch = {key: torch.stack([x[key] for x in list_of_ds]) for key in simple_keys}

    labels = self.tokenizer.batch_encode_plus([x['labels'] for x in list_of_ds], return_tensors = 'pt', is_split_into_words = True,
                                              padding = 'max_length', truncation = True)['input_ids']
    if self.label_pad_token_id is not None:
      ## Without this, every padded position up to max_length would be trained to predict <pad>
      labels = labels.masked_fill(labels == self.tokenizer.pad_token_id, self.label_pad_token_id)
    actual_batch['labels'] = labels
    return actual_batch

In [None]:
## Batch collator shared by the train and validation DataLoaders
collate_fn = CollateFn(tokenizer = tokenizer)

In [None]:
# sample_batch_encoding = collate_fn([train_ds[0], train_ds[1]])
# for key in sample_batch_encoding:
#   sample_batch_encoding[key] = sample_batch_encoding[key].to(device)
# #   print(f"Key : {key}, has shape : {sample_batch_encoding[key].shape}")

## 2.1 Preparing the visual model

In [None]:
class VisualEmbedding(nn.Module):
  """Project RoI-pooled U-Net image features into ``config.d_model``-sized vectors.

  A U-Net encoder produces a feature map from ``pixel_values``; ``RoIPool`` pools
  that map to ``config.output_size`` for the regions given by ``bboxes``; the pooled
  features are flattened from dim 2 onwards and linearly projected to ``config.d_model``
  so they can be added to T5's token embeddings.
  """
  def __init__(self, config):
    super().__init__()
    self.unet_encoder = Unet_encoder(in_channels = config.in_channels, channels = config.channels, num_pool_layers = config.num_pool_layers)
    self.roi_pool = RoIPool(output_size = config.output_size, spatial_scale = config.spatial_scale)

    ## Projection input width = encoder feature channels * pooled height * pooled width.
    ## 128 is the channel depth the original code hard-coded for channels=16, num_pool_layers=3;
    ## NOTE(review): presumably channels * 2 ** num_pool_layers — confirm in visual_backbone.Unet_encoder.
    ## It can be overridden with an optional `visual_feature_channels` config field.
    feature_channels = getattr(config, "visual_feature_channels", 128)
    pool_size = config.output_size
    if isinstance(pool_size, int):
      pool_size = (pool_size, pool_size)
    pool_h, pool_w = pool_size
    self.proj = nn.Linear(in_features = feature_channels * pool_h * pool_w, out_features = config.d_model)
    self.config = config

  def forward(self, pixel_values, bboxes):
    """Return the projected visual embedding for ``bboxes`` on ``pixel_values``.

    NOTE(review): presumably ``pixel_values`` is (batch, C, H, W) and ``bboxes`` is
    (batch, seq_len, 4), giving a (batch, seq_len, d_model) result — confirm against RoIPool.
    """
    image_embedding = self.unet_encoder(pixel_values)
    feature_maps_bboxes = self.roi_pool(image_embedding, bboxes).flatten(2)
    projection = self.proj(feature_maps_bboxes)
    return projection

In [None]:
# visual_embedding_extractor = VisualEmbedding(t5_config).to(device)

In [None]:
# visual_embedding = visual_embedding_extractor(pixel_values = sample_batch_encoding['pixel_values'], bboxes = sample_batch_encoding['bboxes'])

## 2.2 Preparing the semantic model

In [None]:
# t5_model = T5ForConditionalGeneration(t5_config).to(device)

In [None]:
# ## Forward method

# ## Semantic embedding from t5_model's embedding layer
# semantic_embedding = t5_model.shared(sample_batch_encoding['input_ids'])

# ## Net embedding is addition of both the embeddings
# total_embedding = visual_embedding + semantic_embedding

# ## This is then fed to t5_model
# final_output = t5_model(attention_mask = sample_batch_encoding['attention_mask'], inputs_embeds = total_embedding,
#                         labels = sample_batch_encoding['labels'])

In [None]:
## Some rough work

# pretrained_t5_model = AutoModel.from_pretrained(model_name)

# for (name, param), (name_1, param_1) in zip(pretrained_t5_model.named_parameters(), t5_model.named_parameters()): 
#   if name.startswith("decoder"):
#     print(f"{name}   {name_1}")

# t5_model_sd = t5_model.state_dict()
# t5_model_sd_keys = t5_model_sd.keys()

# pretrained_t5_model_sd = pretrained_t5_model.state_dict()
# pretrained_t5_model_sd_keys = pretrained_t5_model_sd.keys()

# t5_model_sd_keys = [k for k in t5_model_sd_keys if not any(["relative_horizontal_bias" in k, "relative_vertical_bias" in k])] # discard this mask / buffer, not a param

# t5_model.load_state_dict(pretrained_t5_model.state_dict(), strict = False)

In [None]:
class TiLTTransformer(nn.Module):
  """T5 encoder-decoder whose input embeddings are visual embeddings plus T5 token embeddings.

  ``batch`` dicts carry the keys produced by ``CollateFn``: ``pixel_values``, ``bboxes``,
  ``input_ids``, ``attention_mask`` and, for training, ``labels``.
  """

  def __init__(self, config):
    super().__init__()
    self.config = config
    self.visual_embedding_extractor = VisualEmbedding(config)
    self.t5_model = T5ForConditionalGeneration(config)

  def generate(self, batch, **generate_kwargs):
    """Decode from the fused embeddings of ``batch``.

    Extra keyword arguments (e.g. ``max_new_tokens``, ``num_beams``) are forwarded
    unchanged to ``self.t5_model.generate``.
    """
    total_embedding = self.common_step(batch)
    ## `inputs_embeds` (not `input_embeds`) is the keyword the T5 model takes, as in `forward`;
    ## the attention mask keeps padded input positions out of the encoder during generation too.
    return self.t5_model.generate(inputs_embeds = total_embedding, attention_mask = batch['attention_mask'],
                                  **generate_kwargs)

  def common_step(self, batch):
    """Return the element-wise sum of the visual and semantic (token) embeddings for ``batch``.

    NOTE(review): both terms are presumably (batch, seq_len, d_model) — they must match or broadcast.
    """
    ## Visual embedding
    visual_embedding = self.visual_embedding_extractor(pixel_values = batch['pixel_values'], bboxes = batch['bboxes'])

    ## Semantic embedding from t5_model's embedding layer
    semantic_embedding = self.t5_model.shared(batch['input_ids'])

    ## Net embedding is addition of both the embeddings
    total_embedding = visual_embedding + semantic_embedding

    return total_embedding

  def forward(self, batch):
    """Run the T5 model on the fused embeddings; the loss is computed only when ``batch`` has ``labels``."""
    total_embedding = self.common_step(batch)

    ## `labels` is optional so the model can also be run for inference / validation without targets
    final_output = self.t5_model(attention_mask = batch['attention_mask'], inputs_embeds = total_embedding,
                                 labels = batch.get('labels'))

    return final_output

In [None]:
# tilt_model = TiLTTransformer(t5_config).to(device)
# output = tilt_model(sample_batch_encoding)

## Counting the parameters of the model sizes reported in the paper

In [None]:
T5_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "t5-base",
    # "t5-large"
]

## Count the parameters of TiLT built on each T5 size listed above.
## Distinct loop names keep the global `model_name` / `t5_config` from the configuration cell intact,
## so later cells still see the base configuration even when more archives are listed here.
for archive_name in T5_PRETRAINED_MODEL_ARCHIVE_LIST:
  archive_config = AutoConfig.from_pretrained(archive_name)
  archive_config.update(dict(in_channels = in_channels, num_pool_layers = num_pool_layers,  channels = channels, model_max_length = model_max_length,
                             output_size = output_size, spatial_scale = spatial_scale, sampling_ratio = sampling_ratio, use_cache = False, load_weights = load_weights))
  tilt_model = TiLTTransformer(archive_config)
  num_params = sum(p.numel() for p in tilt_model.parameters())
  print(f"Model : {archive_name} has {num_params / 1e6:.4f} M parameters")

Weights loaded successfully!
Model : t5-base has 224.2795 M parameters


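**Note:** The paper reports 230M parameters for the base model and 780M for the large model, whereas this implementation gives roughly 225M and 740M. It is not yet clear where the gap comes from (possibly the visual backbone), but the numbers are close enough to continue with this setup for now.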

In [None]:
# from transformers import AutoModel
# for model_name in T5_PRETRAINED_MODEL_ARCHIVE_LIST:
#   tilt_model = AutoModel.from_pretrained(model_name)
#   print(f"Model : {model_name} has {sum(p.numel() for p in tilt_model.parameters()) / 1e6:.4f} M parameters")