In [136]:
import numpy as np
import torch
import torch.nn as nn
import math

In [137]:
def _get_mask(nums, max_num):
    # non_pad_mask: b x lq, torch.float32, 0. on PAD
    batch_size = nums.size(0)
    arange = torch.arange(0, max_num).unsqueeze(0).expand(batch_size, -1)
    non_pad_mask = arange.to(nums.device).lt(nums.unsqueeze(-1))
    non_pad_mask = non_pad_mask.type(torch.float32)
    return non_pad_mask

class OcrPtrNet(nn.Module):
    def __init__(self, hidden_size, query_key_size=None):
        super().__init__()

        if query_key_size is None:
            query_key_size = hidden_size
        self.hidden_size = hidden_size
        self.query_key_size = query_key_size

        self.query = nn.Linear(hidden_size, query_key_size)
        self.key = nn.Linear(hidden_size, query_key_size)

    def forward(self, query_inputs, key_inputs, attention_mask):
        extended_attention_mask = (1.0 - attention_mask) * -10000.0
        assert extended_attention_mask.dim() == 2
        extended_attention_mask = extended_attention_mask.unsqueeze(1)

        query_layer = self.query(query_inputs)
        if query_layer.dim() == 2:
            query_layer = query_layer.unsqueeze(1)
            squeeze_result = True
        else:
            squeeze_result = False
        key_layer = self.key(key_inputs)

        scores = torch.matmul(
            query_layer,
            key_layer.transpose(-1, -2)
        )
        scores = scores / math.sqrt(self.query_key_size)
        scores = scores + extended_attention_mask
        if squeeze_result:
            scores = scores.squeeze(1)

        return scores



In [138]:
ocr_nums = torch.randint(20, (1,8))[0] + 1
ocr_nums

tensor([ 5,  5, 13,  3, 13,  9, 18, 19])

In [139]:
ocr_mask = _get_mask(ocr_nums, 50)
ocr_mask[0]

tensor([1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [140]:
e_mask = (1.0 - ocr_mask) * -10000.0
assert e_mask.dim() == 2

In [141]:
e_mask

tensor([[    -0.,     -0.,     -0.,     -0.,     -0., -10000., -10000., -10000.,
         -10000., -10000., -10000., -10000., -10000., -10000., -10000., -10000.,
         -10000., -10000., -10000., -10000., -10000., -10000., -10000., -10000.,
         -10000., -10000., -10000., -10000., -10000., -10000., -10000., -10000.,
         -10000., -10000., -10000., -10000., -10000., -10000., -10000., -10000.,
         -10000., -10000., -10000., -10000., -10000., -10000., -10000., -10000.,
         -10000., -10000.],
        [    -0.,     -0.,     -0.,     -0.,     -0., -10000., -10000., -10000.,
         -10000., -10000., -10000., -10000., -10000., -10000., -10000., -10000.,
         -10000., -10000., -10000., -10000., -10000., -10000., -10000., -10000.,
         -10000., -10000., -10000., -10000., -10000., -10000., -10000., -10000.,
         -10000., -10000., -10000., -10000., -10000., -10000., -10000., -10000.,
         -10000., -10000., -10000., -10000., -10000., -10000., -10000., -10000.,


In [165]:
bs = 32

In [166]:
top_out = torch.randint(20, (bs, 170, 768)) + 1
top_out = top_out.float()
top_out.shape

torch.Size([32, 170, 768])

In [167]:
bert_dec_output = top_out[:, 0:40, :] # bert output text dec
bert_ocr_output = top_out[:, 120:170, :] # bert out ocr dec
(bert_dec_output.shape, bert_ocr_output.shape)

(torch.Size([32, 40, 768]), torch.Size([32, 50, 768]))

In [168]:
ocr_nums = torch.randint(666, (1, bs))[0] + 1
ocr_nums

# binary mask of valid OCR vs padding
#ocr_nums = len(input_ocr_ids)  # count of ocr tokens found in image


tensor([209, 183, 238, 164, 569, 358,  76, 129, 247, 633, 331, 223, 125, 492,
        353, 114,  86, 147, 426, 297,  86, 552, 345, 551, 656, 284, 591, 658,
        220, 573, 627, 423])

In [169]:
## pad at the end; used anyway by obj, ocr mmt encode
ocr_mask = _get_mask(ocr_nums, 50)

In [170]:
ocr_ptr_net = OcrPtrNet(768)

In [171]:



#fixed_scores = self.cls(top_out[:, :input_ids.shape[-1], :]  # (1, 0:70, 768))
dynamic_ocr_scores = ocr_ptr_net(bert_dec_output, bert_ocr_output, ocr_mask)
#scores = torch.cat([fixed_scores, dynamic_ocr_scores], dim=-1)

In [172]:
dynamic_ocr_scores.shape

torch.Size([32, 40, 50])

In [243]:
pad_pos = torch.Tensor([0, 0, 0, 0, 0, 0])

def get_sample():
    k = np.random.randint(35) + 1
    print(k, end=' ')
    empty_pos = torch.vstack([pad_pos for i in range(50 - k)])
    ocr_pos = torch.vstack([torch.rand_like(pad_pos) for i in range(k)])
    ocr_input = torch.cat([ocr_pos, empty_pos])
    return ocr_input

In [244]:
a = torch.stack([get_sample() for i in range(32)])
a.shape

9 32 25 12 2 32 27 15 23 20 29 12 16 9 26 6 21 19 29 27 23 5 9 5 31 15 16 8 11 21 29 17 

torch.Size([32, 50, 6])

In [246]:
c = torch.sum(a != pad_pos, dim=1)[:, 0]
#c.shape
c

tensor([ 9, 32, 25, 12,  2, 32, 27, 15, 23, 20, 29, 12, 16,  9, 26,  6, 21, 19,
        29, 27, 23,  5,  9,  5, 31, 15, 16,  8, 11, 21, 29, 17])