In [1]:
import torch
from torch import nn

In [2]:
!pip install -q transformers
!pip install -q evaluate
!pip install -q rouge_score
!pip install --upgrade nltk
!pip install -q git+https://github.com/salaniz/pycocoevalcap.git

Collecting nltk
  Downloading nltk-3.8.1-py3-none-any.whl (1.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Installing collected packages: nltk
  Attempting uninstall: nltk
    Found existing installation: nltk 3.2.4
    Uninstalling nltk-3.2.4:
      Successfully uninstalled nltk-3.2.4
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
preprocessing 0.1.13 requires nltk==3.2.4, but you have nltk 3.8.1 which is incompatible.[0m[31m
[0mSuccessfully installed nltk-3.8.1


In [3]:
import nltk
nltk.__version__



'3.8.1'

In [4]:
import glob
import nltk
import math
import torch
import numpy as np
from torch import nn
from random import choice
from tqdm.auto import tqdm
from pycocoevalcap.cider.cider import Cider
import torch.nn.functional as F
from torch.utils.data import Dataset
from nltk.tokenize import word_tokenize
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer, AutoModel
from sklearn.model_selection import train_test_split
nltk.download('punkt')

class ViTextCapsDataset(Dataset):
  def __init__(self, tokenizer, obj_features_folder=None, ocr_features_folder=None, fasttext_token=None, caption_file=None, convert_obj_bbx=True, data=None, mask_value=64000.0):

    self.convert_obj_bbx = convert_obj_bbx
    self.tokenizer = tokenizer
    self.mask_value = mask_value
    self.dummy_tensor = torch.ones((1, 300))

    # if data is None all others argument cant be None
    assert (obj_features_folder is not None and ocr_features_folder is not None \
    and fasttext_token is not None and caption_file is not None) or data is not None, "All other arguments must be passed if data is None!"

    if data is None:
      self.load_data(obj_features_folder, ocr_features_folder, fasttext_token, caption_file)
    else:
      self.data = data

  def load_data(self, obj_features_folder, ocr_features_folder, fasttext_token, caption_file):
    self.data = []

    # Setup the total feature file
    data_paths = glob.glob(ocr_features_folder + '/*')
    fasttext_token_ = np.load(fasttext_token, allow_pickle=True).tolist()
    captions = np.load(caption_file, allow_pickle=True).tolist()

    for path in tqdm(data_paths, desc='Load Data'):

      # Load the feature from each file
      image_name = path.split('/')[-1].split('.')[0]
      ocr_feature = np.load(ocr_features_folder + '/' + image_name + '.npy', allow_pickle=True).tolist()
      obj_info = np.load(obj_features_folder + '/' + image_name + '_info.npy', allow_pickle=True).tolist()
      obj_feature = np.load(obj_features_folder + '/' + image_name + '.npy')

      try:
        sample = {
            'id': image_name,
            'captions': captions[image_name],

            'ocr': {
                'boxes': ocr_feature['boxes'],
                'scores': ocr_feature['scores'], # The confidence score of the model's prediction
                'size': (ocr_feature['weight'], ocr_feature['height']), # xmin, ymin, xmax, ymax
                'texts': ocr_feature['texts'],
                'fasttext_token': fasttext_token_[image_name],
                'rec_features': ocr_feature['rec_features'],
                'det_features': ocr_feature['det_features']
                },

            'obj': {
                'boxes': obj_info['bbox'],
                'scores': obj_info['cls_prob'], # The confidence score of the model's prediction
                'size': (obj_info['image_width'], obj_info['image_height']), # xmin, ymin, xmax, ymax
                'objects': obj_info['objects'], # With int type, the object's label
                'features': obj_feature,
            }
        }
      except:
        continue

      if self.convert_obj_bbx:
        for i in range(len(sample['obj']['boxes'])):
          sample['obj']['boxes'][i] = self.convert(sample['ocr']['size'], sample['obj']['boxes'][i]) # Use ocr-size because the obj-size is already scaled

      self.data.append(sample)

  def convert(self, size, box):
    # size: width, height
    # box: xmin, ymin, xmax, ymax
    w, h = size
    return (box[0] / w, box[1] / h, box[2] / w, box[3] / h)

  def __getitem__(self, idx):
    sample = self.data[idx]
    return {
            'id': sample['id'],
            'captions': sample['captions'],
            'obj_boxes': torch.tensor(sample['obj']['boxes']),
            'obj_features': torch.tensor(sample['obj']['features']),
            'ocr_texts': sample['ocr']['texts'],
            'ocr_boxes': torch.tensor(sample['ocr']['boxes']),
            'ocr_token_embeddings': torch.tensor(sample['ocr']['fasttext_token']) if len(sample['ocr']['fasttext_token']) > 0 else self.dummy_tensor,
            'ocr_rec_features': torch.tensor(sample['ocr']['rec_features']),
            'ocr_det_features': torch.tensor(sample['ocr']['det_features'])
        }

  def split_data(self, validation_size, test_size, random_state=42):

    test_val_size = test_size + validation_size

    # Split train and evaluation set
    train_data, test_val_data = train_test_split(self.data,
                                                 test_size=test_val_size,
                                                 random_state=random_state)

    # Split val and test
    val_data, test_data = train_test_split(test_val_data,
                                           test_size=test_size / test_val_size,
                                           random_state=random_state)

    return (ViTextCapsDataset(tokenizer=self.tokenizer, mask_value=self.mask_value, data=train_data),
            ViTextCapsDataset(tokenizer=self.tokenizer, mask_value=self.mask_value, data=val_data),
            ViTextCapsDataset(tokenizer=self.tokenizer, mask_value=self.mask_value, data=test_data))

  def __len__(self):
    return len(self.data)

  def collate_fn(self, batch):

    batch = dict(zip(batch[0].keys(), zip(*[d.values() for d in batch])))

    # Convert obj list to tensor
    obj_boxes_tensor = torch.stack(batch['obj_boxes'])
    obj_features_tensor = torch.stack(batch['obj_features'])

    # Convert ocr list to tensor
    ocr_boxes_tensor = pad_sequence(batch['ocr_boxes'], batch_first=True, padding_value=self.mask_value)
    ocr_token_embeddings_tensor = pad_sequence(batch['ocr_token_embeddings'], batch_first=True, padding_value=1)
    ocr_rec_features_tensor = pad_sequence(batch['ocr_rec_features'], batch_first=True, padding_value=1)
    ocr_det_features_tensor = pad_sequence(batch['ocr_det_features'], batch_first=True, padding_value=1)

    captions_ = [choice(s) for s in batch['captions']]
    raw_captions = batch['captions']
    batch_id = batch['id']
    texts_ = batch['ocr_texts']

    vs = self.tokenizer.vocab_size + 1
    labels_= []

    # Captions to token
    for i, caption in enumerate(captions_):
      label_ = []

      for token in word_tokenize(caption):

          if token in texts_[i] and token not in self.tokenizer.get_vocab():
            label_.append(texts_[i].index(token) + vs)
          else:
            label_ += self.tokenizer(token)['input_ids'][1: -1]

      label_.append(2) # 2 is <eos> in tokenizer
      labels_.append(torch.tensor(label_))

    # Convert labels_ 2 tensor
    labels_ = pad_sequence(labels_, batch_first=True, padding_value=1)

    dec_mask = torch.ones_like(labels_)
    dec_mask = dec_mask.masked_fill(labels_ == 1, 0) # batch_size, seq_length

    # Get the ocr_attention_mask
    ocr_attn_mask = torch.ones_like(ocr_boxes_tensor)
    ocr_attn_mask = ocr_attn_mask.masked_fill(ocr_boxes_tensor == self.mask_value, 0)[:, :, 0] # batch_size, seq_length
    ocr_boxes_tensor = ocr_boxes_tensor.masked_fill(ocr_boxes_tensor == self.mask_value, 1)

    # Join attention_mask
    obj_attn_mask = torch.ones(size=(obj_boxes_tensor.size(0), obj_boxes_tensor.size(1))) # batch_size, seq_length
    join_attn_mask = torch.cat([obj_attn_mask, ocr_attn_mask, dec_mask], dim=-1)

    return {
          'id': batch_id,
          'obj_boxes': obj_boxes_tensor,
          'obj_features': obj_features_tensor,
          'ocr_boxes': ocr_boxes_tensor,
          'ocr_token_embeddings': ocr_token_embeddings_tensor,
          'ocr_rec_features': ocr_rec_features_tensor,
          'ocr_det_features': ocr_det_features_tensor,
          'join_attn_mask': join_attn_mask,
          'labels': labels_,
          'texts': texts_,
          'raw_captions': raw_captions
    }

[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [5]:
tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base-v2")
phobert_model = AutoModel.from_pretrained("vinai/phobert-base-v2")

data = ViTextCapsDataset(tokenizer,
                         '/kaggle/input/vitextcaps2/combine_obj_features',
                         '/kaggle/input/vitextcaps2/combine_ocr_features',
                         '/kaggle/input/vitextcaps2/token_fasttext_3.npy',
                         '/kaggle/input/vitextcaps2/clean_captions.npy')

config.json:   0%|          | 0.00/678 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/895k [00:00<?, ?B/s]

bpe.codes:   0%|          | 0.00/1.14M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/3.13M [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/540M [00:00<?, ?B/s]

Some weights of RobertaModel were not initialized from the model checkpoint at vinai/phobert-base-v2 and are newly initialized: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Load Data:   0%|          | 0/5289 [00:00<?, ?it/s]

In [6]:
train_data, val_data, test_data = data.split_data(validation_size=0.2, test_size=0.1)

In [7]:
len(train_data), len(val_data), len(test_data)

(3271, 935, 468)

In [8]:
from torch.utils.data import DataLoader
train_dataloader = DataLoader(train_data, batch_size=16, shuffle=True, collate_fn=data.collate_fn)
val_dataloader = DataLoader(val_data, batch_size=16, shuffle=False, collate_fn=data.collate_fn)
test_dataloader = DataLoader(test_data, batch_size=16, shuffle=False, collate_fn=data.collate_fn)

In [9]:
class ObjectEncoder(nn.Module):
  def __init__(self, obj_in_dim, hidden_size, dropout_prob=0.1):
    super().__init__()

    # 2048 (FasterRCNN)
    self.linear_obj_feat_to_mmt_in = nn.Linear(obj_in_dim, hidden_size)

    # OBJ location feature
    self.linear_obj_bbox_to_mmt_in = nn.Linear(4, hidden_size)

    self.obj_feat_layer_norm = nn.LayerNorm(hidden_size)
    self.obj_bbox_layer_norm = nn.LayerNorm(hidden_size)
    self.dropout = nn.Dropout(dropout_prob)

  def forward(self, obj_boxes, obj_features):

    # Features to hidden size
    obj_features = F.normalize(obj_features, dim=-1)

    # Get obj features
    obj_features = self.obj_feat_layer_norm(self.linear_obj_feat_to_mmt_in(obj_features))
    obj_bbox_features = self.obj_bbox_layer_norm(self.linear_obj_bbox_to_mmt_in(obj_boxes))

    return self.dropout(obj_features + obj_bbox_features) # batch_size, seq_length, hidden_size

In [10]:
class OCREncoder(nn.Module):
  def __init__(self, ocr_in_dim, hidden_size, dropout_prob=0.1):
    super().__init__()

    # 300 (FastText) + 256 (rec_features) + 256 (det_features) = 812 # 768
    self.linear_ocr_feat_to_mmt_in = nn.Linear(ocr_in_dim, hidden_size)

    # OCR location feature
    self.linear_ocr_bbox_to_mmt_in = nn.Linear(4, hidden_size)

    self.ocr_feat_layer_norm = nn.LayerNorm(hidden_size)
    self.ocr_bbox_layer_norm = nn.LayerNorm(hidden_size)
    self.dropout = nn.Dropout(dropout_prob)

  def forward(self, ocr_boxes, ocr_token_embeddings, ocr_rec_features, ocr_det_features):

    # Normalize input
    ocr_token_embeddings = F.normalize(ocr_token_embeddings, dim=-1)
    ocr_rec_features = F.normalize(ocr_rec_features, dim=-1)
    ocr_det_features = F.normalize(ocr_det_features, dim=-1)

    # get OCR combine features
    ocr_combine_features = torch.cat([ocr_token_embeddings, ocr_rec_features, ocr_det_features], dim=-1)
    ocr_combine_features = self.ocr_feat_layer_norm(self.linear_ocr_feat_to_mmt_in(ocr_combine_features))

    # Get OCR bbox features
    ocr_bbox_features = self.ocr_bbox_layer_norm(self.linear_ocr_bbox_to_mmt_in(ocr_boxes))

    return self.dropout(ocr_combine_features + ocr_bbox_features) # batch_size, seq_length, hidden_size

In [11]:
class PositionalEncoding(nn.Module):
  def __init__(self, d_model, max_len=512, dropout_prob=0.1):
    super().__init__()

    self.dropout = nn.Dropout(dropout_prob)

    position_ids = torch.arange(max_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
    pe = torch.zeros(size=(1, max_len, d_model))
    pe[0, :, 0::2] = torch.sin(position_ids / div_term)
    pe[0, :, 1::2] = torch.cos(position_ids / div_term)
    self.register_buffer('pe', pe)

  def forward(self, x):
    # x shape (batch_size, seq_length, d_model)
    return x + self.pe[:, :x.size(1), :]

def _batch_gather(x, inds):
    assert x.dim() == 3
    batch_size = x.size(0)
    length = x.size(1)
    dim = x.size(2)

    batch_offsets = torch.arange(batch_size, device=inds.device) * length
    batch_offsets = batch_offsets.unsqueeze(-1)
    assert batch_offsets.dim() == inds.dim()
    results = F.embedding(batch_offsets + inds, x.view(batch_size * length, dim)) # batch_size, T, hidden_size
    return results

class PrevPredEmbeddings(nn.Module):
    def __init__(self, hidden_size, ln_eps=1e-12, dropout_prob=0.1):
        super().__init__()

        self.position_embeddings = PositionalEncoding(hidden_size)
        self.token_type_embeddings = nn.Embedding(2, hidden_size)
        self.token_type_embeddings.weight.data = nn.Parameter(torch.cat([torch.zeros(1, hidden_size), torch.ones(1, hidden_size)]))

        self.ans_layer_norm = nn.LayerNorm(hidden_size, eps=ln_eps)
        self.ocr_layer_norm = nn.LayerNorm(hidden_size, eps=ln_eps)
        self.emb_layer_norm = nn.LayerNorm(hidden_size, eps=ln_eps)
        self.emb_dropout = nn.Dropout(dropout_prob)

    def forward(self, ans_emb, ocr_emb, labels):

        batch_size = labels.size(0)
        seq_length = labels.size(1)
        ans_num = ans_emb.size(0)

        # apply layer normalization to both answer embedding and OCR embedding
        # before concatenation, so that they have the same scale
        ans_emb = self.ans_layer_norm(ans_emb)
        ocr_emb = self.ocr_layer_norm(ocr_emb)
        assert ans_emb.size(-1) == ocr_emb.size(-1)

        # Token type ids: 0 -- vocab; 1 -- OCR
        token_type_embeddings = self.token_type_embeddings(labels.ge(ans_num).long()) # N, T, hidden_size
        embeddings = self.emb_dropout(self.emb_layer_norm(self.position_embeddings(token_type_embeddings)))

        return _batch_gather(torch.cat([ans_emb.unsqueeze(0).expand(batch_size, -1, -1), ocr_emb], dim=1), labels) + embeddings

In [12]:
class MultiHeadAttention(nn.Module):
  def __init__(self, d_model, n_heads, d_k, causal=False):
    super().__init__()

    self.n_heads = n_heads
    self.d_k = d_k

    self.query = nn.Linear(d_model, n_heads * d_k)
    self.key = nn.Linear(d_model, n_heads * d_k)
    self.value = nn.Linear(d_model, n_heads * d_k)

    self.fc = nn.Linear(n_heads * d_k, d_model)
    self.causal = causal

  def forward(self, x, dec_size, attention_mask=None):
    # x shape (batch_size, seq_length, d_model)

    N = x.size(0)
    T = x.size(1)

    # Pass through linear to get q, k, v
    q = self.query(x).view(N, T, self.n_heads, self.d_k).transpose(1, 2) # batch_size, n_heads, T, d_k
    k = self.key(x).view(N, T, self.n_heads, self.d_k).transpose(1, 2) # batch_size, n_heads, T, d_k
    v = self.query(x).view(N, T, self.n_heads, self.d_k).transpose(1, 2)

    # Get the attention scores
    attn_scores = (q @ k.mT) / math.sqrt(self.d_k) # batch_size, n_heads, T_dec, T_enc

    # Mask the padding values and (if causal)
    if attention_mask is not None:
      attn_scores = attn_scores.masked_fill(attention_mask[:, None, None, :] == 0, float('-inf'))

    if self.causal:
      causal_mask = torch.tril(torch.ones(dec_size, dec_size))
      extend_causal_mask = torch.ones((T, T))
      extend_causal_mask[:, -dec_size:] = torch.cat([torch.zeros((T - dec_size, dec_size)), causal_mask])
      extend_causal_mask = extend_causal_mask.to(attn_scores.device)

      attn_scores = attn_scores.masked_fill(extend_causal_mask[None, None, :, :]==0, float('-inf'))

    # Get the attention weights
    attn_weights = torch.softmax(attn_scores, dim=-1)

    # Get the values
    A = attn_weights @ v # batch_size, n_heads, T_dec, d_k

    # Reshape to batch_size, T_dec, n_heads * d_k
    A = A.transpose(1, 2).contiguous().view(N, T, self.n_heads * self.d_k)

    return self.fc(A)

In [13]:
class DecoderBlock(nn.Module):
  def __init__(self, d_model, n_heads, d_k, dropout_prob=0.1):
    super().__init__()

    self.mha = MultiHeadAttention(d_model, n_heads, d_k, causal=True)
    self.ln1 = nn.LayerNorm(d_model)
    self.ln2 = nn.LayerNorm(d_model)
    self.ffn = nn.Sequential(
        nn.Linear(d_model, d_model * 4),
        nn.GELU(),
        nn.Linear(d_model * 4, d_model),
        nn.Dropout(dropout_prob),
    )
    self.dropout = nn.Dropout(dropout_prob)

  def forward(self, x, dec_size, attention_mask=None):

    x = self.ln1(x + self.mha(x, dec_size, attention_mask))
    x = self.ln2(x + self.ffn(x))
    return self.dropout(x)

In [14]:
class Decoder(nn.Module):
  def __init__(self, d_model, n_heads, d_k, n_layers):
    super().__init__()

    self.transformer_blocks = nn.Sequential(*[DecoderBlock(d_model, n_heads, d_k) for _ in range(n_layers)])

  def forward(self, x, dec_size, attention_mask=None):

    for block in self.transformer_blocks:
      x = block(x, dec_size, attention_mask)

    return x # N, T, hidden_size

In [15]:
class MMT(nn.Module):
  def __init__(self, d_model, n_heads, d_k, n_layers):
    super().__init__()
    self.prev_pred_embeddings = PrevPredEmbeddings(d_model)
    self.encoder = Decoder(d_model, n_heads, d_k, n_layers)

  def forward(self, obj_emb, fixed_ans_emb, ocr_emb, prev_inds, attention_mask):

    dec_emb = self.prev_pred_embeddings(fixed_ans_emb, ocr_emb, prev_inds)

    encoder_inputs = torch.cat([obj_emb, ocr_emb, dec_emb], dim=1) # batch_size, T (obj + ocr + dec), 768

    # Get the size for each cause we gonna need that in the pointer network
    obj_max_num = obj_emb.size(1)
    ocr_max_num = ocr_emb.size(1)
    dec_max_num = dec_emb.size(1)

    # offsets of each modality in the joint embedding space
    ocr_begin = obj_max_num
    ocr_end = ocr_begin + ocr_max_num

    encoder_outputs = self.encoder(encoder_inputs, dec_max_num, attention_mask) # N, T, hidden_size

    # mmt_dec_output = encoder_outputs[:, ocr_end:, :] # batch_size, dec_max_num, hidden_size
    # mmt_ocr_output = encoder_outputs[:, ocr_begin: ocr_end, :] # batch_size, ocr_max_num, hidden_size

    return encoder_outputs[:, ocr_begin: ocr_end, :], encoder_outputs[:, ocr_end:, :]

In [16]:
class OcrPtrNet(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)

    def forward(self, query_inputs, key_inputs, attention_mask):

        # query_layer = self.query(query_inputs) # batch_size, dec_max_num, hidden_size
        # key_layer = self.key(key_inputs) # batch_size, ocr_max_num, hidden_size

        scores = self.query(query_inputs) @ self.key(key_inputs).mT
        scores = scores / math.sqrt(self.hidden_size) # batch_size, dec_max_num, ocr_max_num

        scores = scores.masked_fill(attention_mask[:, None, :] == 0, -1e4)
        return scores

In [17]:
class M4C(nn.Module):
  def __init__(self,
               obj_in_dim,
               ocr_in_dim,
               hidden_size,
               n_heads,
               d_k,
               n_layers,
               vocab_size,
               fixed_ans_emb):
    super().__init__()
    self.obj_encoder = ObjectEncoder(obj_in_dim=obj_in_dim, hidden_size=hidden_size)
    self.ocr_encoder = OCREncoder(ocr_in_dim=ocr_in_dim, hidden_size=hidden_size)
    self.mmt = MMT(d_model=hidden_size, n_heads=n_heads, d_k=d_k, n_layers=n_layers)
    self.ocr_ptn = OcrPtrNet(hidden_size=hidden_size)
    self.classifier = nn.Linear(hidden_size, vocab_size)
    self.fixed_ans_emb = fixed_ans_emb
    self.finetune_modules = [{"module": self.obj_encoder.linear_obj_feat_to_mmt_in, "lr_scale": 0.1},
                             {"module": self.ocr_encoder.linear_ocr_feat_to_mmt_in, "lr_scale": 0.1},
                             {"module": self.mmt, "lr_scale": 1}]

  def get_optimizer_parameters(self, base_lr):
        optimizer_param_groups = []

        # collect all the parameters that need different/scaled lr
        finetune_params_set = set()
        for m in self.finetune_modules:
            optimizer_param_groups.append(
                {
                    "params": list(m["module"].parameters()),
                    "lr": base_lr * m["lr_scale"],
                }
            )
            finetune_params_set.update(list(m["module"].parameters()))
        # remaining_params are those parameters w/ default lr
        remaining_params = [
            p for p in self.parameters() if p not in finetune_params_set
        ]
        # put the default lr parameters at the beginning
        # so that the printed lr (of group 0) matches the default lr
        optimizer_param_groups.insert(0, {"params": remaining_params})

        return optimizer_param_groups

  def forward(self, sample, device='cpu'):

    obj_emb = self.obj_encoder(sample['obj_boxes'].to(device), sample['obj_features'].to(device))
    ocr_emb = self.ocr_encoder(sample['ocr_boxes'].to(device),
                               sample['ocr_token_embeddings'].to(device),
                               sample['ocr_rec_features'].to(device),
                               sample['ocr_det_features'].to(device))

    # Create decoder inputs
    dec_input = sample['labels'].clone().detach().roll(shifts=1, dims=1)
    dec_input[:, 0] = 0 # <s> token

    mmt_ocr_output, mmt_dec_output = self.mmt(obj_emb,
                                              self.fixed_ans_emb,
                                              ocr_emb,
                                              dec_input.to(device),
                                              sample['join_attn_mask'].to(device))

    ocr_begin = obj_emb.size(1)
    ocr_end = obj_emb.size(1) + mmt_ocr_output.size(1)

    # dynamic_ocr_scores = self.ocr_ptn(mmt_dec_output, mmt_ocr_output, sample['join_attn_mask'][:, ocr_begin: ocr_end].to(device))
    # fixed_scores = self.classifier(mmt_dec_output)

    scores = torch.cat([self.classifier(mmt_dec_output),
                        self.ocr_ptn(mmt_dec_output, mmt_ocr_output, sample['join_attn_mask'][:, ocr_begin: ocr_end].to(device))],
                        dim=-1)

    return scores

In [18]:
def convert_prediction_to_ans(prediction_list, sample_texts, tokenizer):
    raw_predicts = []

    for l, text in zip(prediction_list, sample_texts):
        raw_predict = []
        vocab_predict = []

        for w in l:
            if w > 64000: # Kiểm tra nếu w là ocr token
                if len(vocab_predict) > 0: # Kiểm tra nếu trước đó đã có các vocab token
                    raw_predict.append(tokenizer.decode(vocab_predict))
                    vocab_predict = []

                raw_predict.append(text[w - 64001])
            else:
                vocab_predict += [w]

        if len(vocab_predict) > 0:
                raw_predict.append(tokenizer.decode(vocab_predict))

        caption = ' '.join(raw_predict)
        raw_predicts.append(caption)

    return raw_predicts

In [19]:
def load_checkpoint(filepath):
    checkpoint = torch.load(filepath)
    model = checkpoint['model']
    model.load_state_dict(checkpoint['state_dict'])
    for parameter in model.parameters():
        parameter.requires_grad = False

    model.eval()
    return model

In [20]:
model = load_checkpoint('/kaggle/input/vitextcapscheckpointfinalform/checkpoint (1).pth')

In [21]:
sample = next(iter(test_dataloader))

In [22]:
def decode_prediction(obj_boxes,
                      obj_features,
                      ocr_boxes,
                      ocr_tokens,
                      ocr_token_embeddings,
                      ocr_rec_features,
                      ocr_det_features, 
                      join_obj_ocr_mask, 
                      tokenizer, 
                      fixed_ans_emb):
    
    pred_ans = [tokenizer.bos_token_id]
    dec_input = torch.tensor([pred_ans])

    obj_emb = model.obj_encoder(obj_boxes=obj_boxes.unsqueeze(0),
                                obj_features=obj_features.unsqueeze(0))

    ocr_emb = model.ocr_encoder(ocr_boxes=ocr_boxes.unsqueeze(0),
                                ocr_token_embeddings=ocr_token_embeddings.unsqueeze(0),
                                ocr_rec_features=ocr_rec_features.unsqueeze(0),
                                ocr_det_features=ocr_det_features.unsqueeze(0))

    join_obj_ocr_mask = join_obj_ocr_mask.unsqueeze(0)
    ocr_begin = obj_emb.size(1)
    ocr_end = ocr_begin + ocr_emb.size(1)
    
    for i in tqdm(range(50)):
    
        dec_mask = torch.ones((1, dec_input.size(1)), dtype=torch.int64)
        attn_mask = torch.cat([join_obj_ocr_mask, dec_mask], dim=-1)

        mmt_ocr_output, mmt_dec_output = model.mmt(obj_emb,
                                                  fixed_ans_emb,
                                                  ocr_emb,
                                                  dec_input,
                                                  attn_mask)

#         dynamic_ocr_scores = model.ocr_ptn(mmt_dec_output, mmt_ocr_output, join_obj_ocr_mask[0, ocr_begin: ocr_end].unsqueeze(0))
        fixed_scores = model.classifier(mmt_dec_output)
#         scores = torch.cat([fixed_scores, dynamic_ocr_scores], dim=-1)
        scores = fixed_scores
        scores = scores.argmax(dim=-1)
        scores = scores[0][-1].item()

        if scores == tokenizer.eos_token_id:
            break

        pred_ans.append(scores)
        dec_input = torch.tensor([pred_ans])
        
    decode_ans = convert_prediction_to_ans([pred_ans], [ocr_tokens], tokenizer)
    return decode_ans

In [23]:
sample = next(iter(test_dataloader))

In [24]:
i = 10 
obj_length = sample['obj_boxes'][i].size(0)
ocr_length = sample['ocr_boxes'][i].size(0)

decode_prediction(obj_boxes=sample['obj_boxes'][i],
                  obj_features=sample['obj_features'][i],
                  ocr_boxes=sample['ocr_boxes'][i],
                  ocr_tokens=sample['texts'][i],
                  ocr_token_embeddings=sample['ocr_token_embeddings'][i],
                  ocr_rec_features=sample['ocr_rec_features'][i],
                  ocr_det_features=sample['ocr_det_features'][i],
                  join_obj_ocr_mask=sample['join_attn_mask'][i, :obj_length + ocr_length], 
                  tokenizer=tokenizer, 
                  fixed_ans_emb=phobert_model.embeddings.word_embeddings.weight.data)

  0%|          | 0/50 [00:00<?, ?it/s]

['<s> mọi người đang mua sắm.']

In [80]:
obj_length = sample['obj_boxes'][0].size(0)
ocr_length = sample['ocr_boxes'][0].size(0)
sample['join_attn_mask'][0, : obj_length + ocr_length].shape

torch.Size([121])

In [78]:
ocr_length

85

In [79]:
obj_length

36

In [95]:
# Assuming you have the following sequence lengths for each modality
txt_max_num = 10
obj_max_num = 15
ocr_max_num = 12
dec_max_num = 5

# Assuming a batch size of 2 for demonstration purposes
batch_size = 2

# Creating random binary masks for each modality
txt_mask = torch.randint(2, size=(batch_size, txt_max_num))
obj_mask = torch.randint(2, size=(batch_size, obj_max_num))
ocr_mask = torch.randint(2, size=(batch_size, ocr_max_num))

# Creating a zero mask for decoding steps
dec_mask = torch.zeros(batch_size, dec_max_num)

In [99]:
def _get_causal_mask(seq_length, device):
    # generate a lower triangular mask
    mask = torch.zeros(seq_length, seq_length, device=device)
    for i in range(seq_length):
        for j in range(i + 1):
            mask[i, j] = 1.0
    return mask

def generate_extended_attention_mask(attention_mask, dec_max_num, device):
    to_seq_length = attention_mask.size(1)
    from_seq_length = to_seq_length

    # generate the attention mask similar to prefix LM
    # all elements can attend to the elements in encoding steps
    extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
    extended_attention_mask = extended_attention_mask.repeat(1, 1, from_seq_length, 1)

    # decoding step elements can attend to themselves in a causal manner
    extended_attention_mask[:, :, -dec_max_num:, -dec_max_num:] = _get_causal_mask(
        dec_max_num, device
    )

    # flip the mask, so that invalid attention pairs have -10000.
    extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

    return extended_attention_mask

# Example usage:
# Assuming you have the necessary inputs: txt_mask, obj_mask, ocr_mask, dec_mask
# (you can replace these with the actual masks from your data)
# and dec_max_num (the number of decoding steps)
attention_mask = torch.cat([txt_mask, obj_mask, ocr_mask, dec_mask], dim=1)
extended_attention_mask = generate_extended_attention_mask(
    attention_mask, dec_max_num, 'cpu')

In [101]:
extended_attention_mask[0]

tensor([[[    -0., -10000., -10000.,  ..., -10000., -10000., -10000.],
         [    -0., -10000., -10000.,  ..., -10000., -10000., -10000.],
         [    -0., -10000., -10000.,  ..., -10000., -10000., -10000.],
         ...,
         [    -0., -10000., -10000.,  ...,     -0., -10000., -10000.],
         [    -0., -10000., -10000.,  ...,     -0.,     -0., -10000.],
         [    -0., -10000., -10000.,  ...,     -0.,     -0.,     -0.]]])