In [3]:
import math
import torch
from torch import nn
from torch.nn.functional import gelu
from transformers.modeling_bert import BertLayerNorm
from transformers import BertTokenizer
from lxrt.modeling import BertPreTrainedModel

I0123 03:43:41.146125 139748229315456 file_utils.py:35] PyTorch version 1.4.0+cpu available.
  from ._conv import register_converters as _register_converters


In [4]:
from lxrt.entry import InputFeatures,convert_sents_to_features,set_visual_config
from lxrt.modeling import VISUAL_CONFIG

In [5]:
from transformers import BertConfig
# Shared default configuration; the walkthrough and demo cells below read
# sizes such as hidden_size and num_attention_heads from it.
bert_config = BertConfig()

In [6]:
class GeLU(nn.Module):
    """Module wrapper around ``torch.nn.functional.gelu``.

    Lets GELU be placed wherever an ``nn.Module`` is expected (for example
    inside ``nn.Sequential``). It has no parameters or buffers.
    """

    def forward(self, x):
        activated = torch.nn.functional.gelu(x)
        return activated

In [7]:
class Args:
    """Minimal stand-in for the command-line arguments LXMERT expects.

    Positional order is (language, cross-modality, relational) layer counts,
    stored as ``llayers``, ``xlayers`` and ``rlayers``.
    """

    def __init__(self, l_layers, x_layers, r_layers):
        self.llayers, self.xlayers, self.rlayers = l_layers, x_layers, r_layers
        # False keeps the pretrained weights; LXRTEncoder_ re-initialises
        # every weight only when this is true.
        self.from_scratch = False


# 9 language, 5 cross-modality and 5 relational layers.
args = Args(9, 5, 5)
MAX_VQA_LENGTH = 20  # maximum tokenised question length

In [8]:
## BertEmbeddings
class BertEmbeddings(nn.Module):
    """Word + position + token-type embeddings, followed by LayerNorm and dropout.

    Args:
        config: exposes ``vocab_size``, ``hidden_size``,
            ``max_position_embeddings``, ``type_vocab_size`` and
            ``hidden_dropout_prob``.

    ``forward(input_ids, token_type_ids=None)`` takes integer ids of shape
    (batch, seq_len) and returns (batch, seq_len, hidden_size). Every row uses
    positions 0..seq_len-1; missing token types default to all zeros.
    """
    def __init__(self, config):
        super(BertEmbeddings, self).__init__()
        hidden = config.hidden_size
        # Attribute names are state_dict keys and creation order fixes the
        # random init order, so both are kept stable.
        # NOTE(review): padding_idx=0 is also set on the position and
        # token-type tables (not just words); per nn.Embedding semantics row 0
        # of those tables gets no gradient — confirm this is intended.
        self.word_embeddings = nn.Embedding(config.vocab_size, hidden, padding_idx=0)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, hidden, padding_idx=0)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, hidden, padding_idx=0)

        self.LayerNorm = BertLayerNorm(hidden, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None):
        positions = torch.arange(input_ids.size(1), dtype=torch.long, device=input_ids.device)
        positions = positions.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        summed = (self.word_embeddings(input_ids)
                  + self.position_embeddings(positions)
                  + self.token_type_embeddings(token_type_ids))
        return self.dropout(self.LayerNorm(summed))
    
# torch.rand(...).long() truncates every value in [0, 1) to 0, so ids built
# that way only ever look up token id 0 and segment 0. Draw ids from the
# valid ranges instead so every embedding table is actually exercised.
bert_embeddings = BertEmbeddings(bert_config)
demo_input_ids = torch.randint(0, bert_config.vocab_size, (128, 20))
demo_token_type_ids = torch.randint(0, bert_config.type_vocab_size, (128, 20))
output = bert_embeddings(input_ids=demo_input_ids, token_type_ids=demo_token_type_ids)
output.shape

torch.Size([128, 20, 768])

In [13]:
# Attention geometry derived from the shared bert_config.
num_attention_heads = bert_config.num_attention_heads                 # 12
attention_head_size = bert_config.hidden_size // num_attention_heads  # 768 // 12 = 64
all_head_size = attention_head_size * num_attention_heads             # 64 * 12 = 768

# Query projects the attending sequence; key and value project the context.
ctx_dim = bert_config.hidden_size
query, key, value = (
    nn.Linear(bert_config.hidden_size, all_head_size),  # 768x768
    nn.Linear(ctx_dim, all_head_size),                  # 768x768
    nn.Linear(ctx_dim, all_head_size),                  # 768x768
)

dropout = nn.Dropout(bert_config.attention_probs_dropout_prob)

def transpose_for_scores(x):
    """Split the last axis into heads: (B, L, H*d) -> (B, H, L, d).

    Head count and head width come from the module-level
    ``num_attention_heads`` and ``attention_head_size``.
    """
    *leading, _ = x.size()
    split = x.view(*leading, num_attention_heads, attention_head_size)
    return split.permute(0, 2, 1, 3)

In [14]:
hidden_states = torch.rand(128, 36, 768)  # attending sequence (36 positions, e.g. visual objects)
context = torch.rand(128, 20, 768)        # attended sequence (20 positions, e.g. word tokens)

# Build the additive mask the same way LXRTModel.forward does: 1 = attend,
# 0 = padding, mapped to 0.0 / -10000.0 and shaped (batch, 1, 1, ctx_len) so it
# broadcasts over heads and query positions. A raw torch.rand(...) "mask" only
# nudges scores by values in [0, 1) and never masks anything.
context_padding = (torch.rand(128, 20) > 0.2).long()
attention_mask = (1.0 - context_padding[:, None, None, :].float()) * -10000.0

In [15]:
# Queries come from hidden_states; keys and values from context.
q, k, v = query(hidden_states), key(context), value(context)
for projected in (q, k, v):
    print(projected.shape)

torch.Size([128, 36, 768])
torch.Size([128, 20, 768])
torch.Size([128, 20, 768])


In [16]:
# Split each projection into heads: (B, L, 768) -> (B, 12, L, 64).
q, k, v = (transpose_for_scores(t) for t in (q, k, v))
for per_head in (q, k, v):
    print(per_head.shape)

torch.Size([128, 12, 36, 64])
torch.Size([128, 12, 20, 64])
torch.Size([128, 12, 20, 64])


In [17]:
# Keys with the last two axes swapped, ready for q @ k^T.
k.transpose(-1, -2).shape

torch.Size([128, 12, 64, 20])

In [18]:
# Scaled dot-product scores: (B, H, Lq, d) @ (B, H, d, Lk) -> (B, H, Lq, Lk),
# i.e. [128, 12, 36, 20] here (36 query positions over 20 context positions).
attention_scores = torch.matmul(q, k.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(attention_head_size)
print(attention_scores.shape)

torch.Size([128, 12, 36, 20])


In [19]:
# Additive mask of shape (B, 1, 1, Lk) broadcasts over heads and query positions.
# NOTE: not idempotent — re-running this cell adds the mask a second time.
attention_scores = attention_scores + attention_mask

In [7]:
## BertAttention

class BertAttention(nn.Module):
    """Multi-head scaled dot-product attention of ``hidden_states`` over ``context``.

    Queries come from ``hidden_states`` (width ``config.hidden_size``); keys and
    values come from ``context`` (width ``ctx_dim``, defaulting to
    ``config.hidden_size``). Passing the same tensor twice gives self-attention.

    Example::

        bert_att = BertAttention(bert_config)
        context_output = bert_att(hidden_states=torch.rand(128, 20, 768),
                                  context=torch.rand(128, 36, 768),
                                  attention_mask=None)
        context_output.shape  # [128, 20, 768]

    Args:
        config: provides ``hidden_size``, ``num_attention_heads`` and
            ``attention_probs_dropout_prob``.
        ctx_dim: feature width of ``context``; ``None`` means ``hidden_size``.

    Raises:
        ValueError: if ``hidden_size`` is not a multiple of ``num_attention_heads``.
    """
    def __init__(self, config, ctx_dim=None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.num_attention_heads = config.num_attention_heads  # e.g. 12
        # Exact integer division, guaranteed by the check above (e.g. 768 // 12 = 64).
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        if ctx_dim is None:
            ctx_dim = config.hidden_size
        # Attribute names are state_dict keys; keep them for checkpoint loading.
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(ctx_dim, self.all_head_size)
        self.value = nn.Linear(ctx_dim, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Reshape (B, L, H*d) -> (B, H, L, d)."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, context, attention_mask=None):
        """Attend from ``hidden_states`` (B, Lq, hidden) over ``context`` (B, Lk, ctx_dim).

        ``attention_mask``, if given, is added to the raw scores and must
        broadcast to (B, H, Lq, Lk), e.g. (B, 1, 1, Lk) holding 0 for kept and a
        large negative value for masked context positions.

        Returns a (B, Lq, all_head_size) tensor.
        """
        # Shape comments follow the class docstring example:
        # B=128, Lq=20 (hidden_states), Lk=36 (context), H=12, d=64.
        query_layer = self.transpose_for_scores(self.query(hidden_states))  # [128, 12, 20, 64]
        key_layer = self.transpose_for_scores(self.key(context))            # [128, 12, 36, 64]
        value_layer = self.transpose_for_scores(self.value(context))        # [128, 12, 36, 64]

        # [128, 12, 20, 64] x [128, 12, 64, 36] -> [128, 12, 20, 36]
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        # Normalise over the context axis.
        attention_probs = torch.softmax(attention_scores, dim=-1)  # [128, 12, 20, 36]

        # As in the original Transformer, this drops out entire context tokens
        # from a query's view.
        attention_probs = self.dropout(attention_probs)

        # [128, 12, 20, 36] x [128, 12, 36, 64] -> [128, 12, 20, 64]
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()  # [128, 20, 12, 64]
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        return context_layer.view(*new_context_layer_shape)             # [128, 20, 768]
    
bert_att = BertAttention(bert_config)
# 20 query positions attending over 36 context positions (e.g. words over objects).
context_output = bert_att(hidden_states=torch.rand(128, 20, 768),
                          context=torch.rand(128, 36, 768),
                          attention_mask=None)
context_output.shape

torch.Size([128, 20, 768])

In [8]:
## BertAttOutput

class BertAttOutput(nn.Module):
    """Post-attention block: Linear -> dropout -> residual add -> LayerNorm.

    Example::

        bert_att_output = BertAttOutput(bert_config)
        output = bert_att_output(torch.rand(128, 20, 768), torch.rand(128, 20, 768))
        output.shape  # [128, 20, 768]
    """
    def __init__(self, config):
        super(BertAttOutput, self).__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = BertLayerNorm(width, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Project ``hidden_states`` and add it onto the residual ``input_tensor``."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)

bert_att_output = BertAttOutput(bert_config)
attended, residual = torch.rand(128, 20, 768), torch.rand(128, 20, 768)
output = bert_att_output(attended, residual)
output.shape

torch.Size([128, 20, 768])

In [9]:
## BertCross Attention
class BertCrossattLayer(nn.Module):
    """Cross-attention sub-layer: attend over a context, then residual + LayerNorm.

    Example::

        bert_cross_att = BertCrossattLayer(bert_config)
        output = bert_cross_att(input_tensor=torch.rand(128, 20, 768),
                                ctx_tensor=torch.rand(128, 36, 768),
                                ctx_att_mask=None)
        output.shape  # [128, 20, 768]
    """
    def __init__(self, config):
        super().__init__()
        self.att = BertAttention(config)
        self.output = BertAttOutput(config)

    def forward(self, input_tensor, ctx_tensor, ctx_att_mask=None):
        """Queries from ``input_tensor``, keys/values from ``ctx_tensor``; ``input_tensor`` is the residual."""
        return self.output(self.att(input_tensor, ctx_tensor, ctx_att_mask), input_tensor)
    
bert_cross_att = BertCrossattLayer(bert_config)
queries = torch.rand(128, 20, 768)
contexts = torch.rand(128, 36, 768)
output = bert_cross_att(input_tensor=queries, ctx_tensor=contexts, ctx_att_mask=None)
output.shape

torch.Size([128, 20, 768])

In [10]:
## BertSelfattLayer

class BertSelfattLayer(nn.Module):
    """Self-attention sub-layer: a sequence attends to itself, then residual + LayerNorm.

    Example::

        bert_self_att_layer = BertSelfattLayer(bert_config)
        output = bert_self_att_layer(input_tensor=torch.rand(128, 20, 768),
                                     attention_mask=torch.rand(128, 1, 1, 20))
        output.shape  # [128, 20, 768]
    """
    def __init__(self, config):
        super(BertSelfattLayer, self).__init__()
        # NOTE: the attribute name is part of the state_dict key; presumably kept
        # as ``self`` to match BERT checkpoint naming.
        self.self = BertAttention(config)
        self.output = BertAttOutput(config)

    def forward(self, input_tensor, attention_mask):
        """Queries, keys and values all come from ``input_tensor``."""
        attended = self.self(input_tensor, input_tensor, attention_mask)
        return self.output(attended, input_tensor)
    
bert_self_att_layer = BertSelfattLayer(bert_config)
tokens = torch.rand(128, 20, 768)
additive_mask = torch.rand(128, 1, 1, 20)
output = bert_self_att_layer(input_tensor=tokens, attention_mask=additive_mask)
output.shape

torch.Size([128, 20, 768])

In [11]:
## BertIntermediate

class BertIntermediate(nn.Module):
    """First half of the feed-forward block: Linear(hidden -> intermediate) + activation.

    ``config.hidden_act`` is either a string naming the activation (``"gelu"``
    or ``"relu"``) or a callable that is used as-is.

    Example::

        bert_intermediate = BertIntermediate(bert_config)
        output = bert_intermediate(torch.rand(128, 20, 768))
        output.shape  # [128, 20, 3072]

    Raises:
        ValueError: if ``hidden_act`` is a string naming an unsupported activation.
    """
    def __init__(self, config):
        super(BertIntermediate, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.intermediate_act_fn = self._resolve_activation(config.hidden_act)

    @staticmethod
    def _resolve_activation(hidden_act):
        """Map an activation name to a module; pass callables through unchanged."""
        # Python 3 only (the notebook uses zero-arg super()), so ``str`` covers
        # every textual name; no Python-2 ``unicode`` check is needed.
        if not isinstance(hidden_act, str):
            return hidden_act
        factories = {"gelu": GeLU, "relu": nn.ReLU}
        if hidden_act not in factories:
            raise ValueError("Unsupported hidden_act %r; expected one of %s or a callable"
                             % (hidden_act, sorted(factories)))
        return factories[hidden_act]()

    def forward(self, hidden_states):
        return self.intermediate_act_fn(self.dense(hidden_states))

bert_intermediate = BertIntermediate(bert_config)
token_vectors = torch.rand(128, 20, 768)
output = bert_intermediate(token_vectors)  # expands to intermediate_size
output.shape

torch.Size([128, 20, 3072])

In [12]:
## BertOutput

class BertOutput(nn.Module):
    """Second half of the feed-forward block.

    Linear(intermediate -> hidden), dropout, residual add, then LayerNorm.

    Example::

        bert_output = BertOutput(bert_config)
        output = bert_output(hidden_states=torch.rand(128, 20, 3072),
                             input_tensor=torch.rand(128, 20, 768))
        output.shape  # [128, 20, 768]
    """
    def __init__(self, config):
        super(BertOutput, self).__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Project back to ``hidden_size`` and add the residual ``input_tensor``."""
        shrunk = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(shrunk + input_tensor)

bert_output = BertOutput(bert_config)
ffn_hidden = torch.rand(128, 20, 3072)  # intermediate activations
residual = torch.rand(128, 20, 768)     # residual added back before LayerNorm
output = bert_output(hidden_states=ffn_hidden, input_tensor=residual)
output.shape

torch.Size([128, 20, 768])

In [13]:
## BertLayer

class BertLayer(nn.Module):
    """One Transformer encoder layer: self-attention, then the feed-forward block.

    Example::

        bert_layer = BertLayer(bert_config)
        output = bert_layer(torch.rand(128, 20, 768), torch.rand(128, 1, 1, 20))
        output.shape  # [128, 20, 768]
    """
    def __init__(self, config):
        super(BertLayer, self).__init__()
        self.attention = BertSelfattLayer(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, hidden_states, attention_mask):
        attended = self.attention(hidden_states, attention_mask)
        # Feed-forward expands then projects back; the attention output is the residual.
        return self.output(self.intermediate(attended), attended)
    
bert_layer = BertLayer(bert_config)
layer_input = torch.rand(128, 20, 768)
layer_mask = torch.rand(128, 1, 1, 20)
output = bert_layer(hidden_states=layer_input, attention_mask=layer_mask)
output.shape

torch.Size([128, 20, 768])

## LXRTXLayer

In [14]:
class LXRTXLayer(nn.Module):
    """Cross-modality layer: cross-attention, then per-modality self-attention and FFN.

    A single ``BertCrossattLayer`` is shared by both directions (language over
    vision and vision over language). The self-attention and feed-forward
    sub-layers are separate for each modality.

    Example::

        lxrtx_layer = LXRTXLayer(bert_config)
        lang_output, visn_output = lxrtx_layer(
            lang_feats=torch.rand(128, 20, 768),
            lang_attention_mask=torch.rand(128, 1, 1, 20),
            visn_feats=torch.rand(128, 36, 768),
            visn_attention_mask=None)
        lang_output.shape  # [128, 20, 768]
        visn_output.shape  # [128, 36, 768]
    """
    def __init__(self, config):
        super().__init__()
        # Shared cross-attention, used in both directions by cross_att.
        self.visual_attention = BertCrossattLayer(config)

        # Per-modality self-attention.
        self.lang_self_att = BertSelfattLayer(config)
        self.visn_self_att = BertSelfattLayer(config)

        # Per-modality feed-forward (intermediate + output).
        self.lang_inter = BertIntermediate(config)
        self.lang_output = BertOutput(config)
        self.visn_inter = BertIntermediate(config)
        self.visn_output = BertOutput(config)

    def cross_att(self, lang_input, lang_attention_mask, visn_input, visn_attention_mask):
        """Each modality attends over the other; a modality's mask applies when it is the context."""
        lang_from_visn = self.visual_attention(lang_input, visn_input, ctx_att_mask=visn_attention_mask)
        visn_from_lang = self.visual_attention(visn_input, lang_input, ctx_att_mask=lang_attention_mask)
        return lang_from_visn, visn_from_lang

    def self_att(self, lang_input, lang_attention_mask, visn_input, visn_attention_mask):
        """Each modality attends over itself."""
        return (self.lang_self_att(lang_input, lang_attention_mask),
                self.visn_self_att(visn_input, visn_attention_mask))

    def output_fc(self, lang_input, visn_input):
        """Per-modality feed-forward block with residual connection."""
        lang_out = self.lang_output(self.lang_inter(lang_input), lang_input)
        visn_out = self.visn_output(self.visn_inter(visn_input), visn_input)
        return lang_out, visn_out

    def forward(self, lang_feats, lang_attention_mask,
                visn_feats, visn_attention_mask):
        """Cross-attention, then self-attention, then FFNs; returns ``(lang, visn)`` features."""
        lang_x, visn_x = self.cross_att(lang_feats, lang_attention_mask,
                                        visn_feats, visn_attention_mask)
        lang_x, visn_x = self.self_att(lang_x, lang_attention_mask,
                                       visn_x, visn_attention_mask)
        return self.output_fc(lang_x, visn_x)

In [15]:
lxrtx_layer = LXRTXLayer(bert_config)
output = lxrtx_layer(lang_feats=torch.rand(128, 20, 768),
                     lang_attention_mask=torch.rand(128, 1, 1, 20),
                     visn_feats=torch.rand(128, 36, 768),
                     visn_attention_mask=None)
lang_out, visn_out = output
lang_out.shape, visn_out.shape

(torch.Size([128, 20, 768]), torch.Size([128, 36, 768]))

## VisualFeatEncoder

In [16]:
class VisualFeatEncoder(nn.Module):
    """Embed object features and box features into the hidden space.

    Input widths come from ``VISUAL_CONFIG.visual_feat_dim`` and
    ``VISUAL_CONFIG.visual_pos_dim``; the output width is ``config.hidden_size``.

    Example::

        visual_feat_encoder = VisualFeatEncoder(bert_config)
        # input is a (features, boxes) pair
        output = visual_feat_encoder((torch.rand(128, 36, 2048), torch.rand(128, 36, 4)))
        output.shape  # [128, 36, 768]
    """
    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size

        # Object feature branch.
        self.visn_fc = nn.Linear(VISUAL_CONFIG.visual_feat_dim, hidden)
        self.visn_layer_norm = BertLayerNorm(hidden, eps=1e-12)

        # Box position branch.
        self.box_fc = nn.Linear(VISUAL_CONFIG.visual_pos_dim, hidden)
        self.box_layer_norm = BertLayerNorm(hidden, eps=1e-12)

        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, visn_input):
        """``visn_input`` is ``(feats, boxes)``; the two normalised embeddings are averaged."""
        feats, boxes = visn_input
        feat_emb = self.visn_layer_norm(self.visn_fc(feats))
        box_emb = self.box_layer_norm(self.box_fc(boxes))
        return self.dropout((feat_emb + box_emb) / 2)

In [17]:
visual_feat_encoder = VisualFeatEncoder(bert_config)
obj_feats, obj_boxes = torch.rand(128, 36, 2048), torch.rand(128, 36, 4)
output = visual_feat_encoder((obj_feats, obj_boxes))
output.shape

torch.Size([128, 36, 768])

## LXRT Encoder

Combines `VisualFeatEncoder` with stacks of `BertLayer` (language and relational layers) and `LXRTXLayer` (cross-modality layers).

In [18]:
class LXRTEncoder(nn.Module):
    """LXMERT backbone: language layers, relational (vision) layers, then cross-modality layers.

    Layer counts come from ``VISUAL_CONFIG.l_layers`` / ``x_layers`` / ``r_layers``.
    Word embeddings are applied by the caller; visual inputs are embedded here
    by ``VisualFeatEncoder``.

    Example::

        lxrt_encoder = LXRTEncoder(bert_config)
        lang_feats, visn_feats = lxrt_encoder(
            lang_feats=torch.rand(128, 20, 768),
            lang_attention_mask=torch.rand(128, 1, 1, 20),
            visn_feats=(torch.rand(128, 36, 2048), torch.rand(128, 36, 4)),
            visn_attention_mask=None)
        lang_feats.shape  # [128, 20, 768]
        visn_feats.shape  # [128, 36, 768]
    """
    def __init__(self, config):
        super().__init__()

        # Object-level image embedding.
        self.visn_fc = VisualFeatEncoder(config)

        self.num_l_layers = VISUAL_CONFIG.l_layers
        self.num_x_layers = VISUAL_CONFIG.x_layers
        self.num_r_layers = VISUAL_CONFIG.r_layers
        print("LXRT encoder with %d l_layers, %d x_layers, and %d r_layers." %
              (self.num_l_layers, self.num_x_layers, self.num_r_layers))

        # Named ``layer`` (not ``l_layer``) so BERT encoder weights load by name.
        self.layer = nn.ModuleList(BertLayer(config) for _ in range(self.num_l_layers))
        self.x_layers = nn.ModuleList(LXRTXLayer(config) for _ in range(self.num_x_layers))
        self.r_layers = nn.ModuleList(BertLayer(config) for _ in range(self.num_r_layers))

    def forward(self, lang_feats, lang_attention_mask,
                visn_feats, visn_attention_mask=None):
        """``visn_feats`` is a ``(feats, boxes)`` pair; returns ``(lang_feats, visn_feats)``."""
        # Only the visual embedding runs here; word embeddings were applied by the
        # caller so BERT embedding weights can be loaded separately.
        visn_feats = self.visn_fc(visn_feats)

        for lang_layer in self.layer:
            lang_feats = lang_layer(lang_feats, lang_attention_mask)

        for rel_layer in self.r_layers:
            visn_feats = rel_layer(visn_feats, visn_attention_mask)

        for cross_layer in self.x_layers:
            lang_feats, visn_feats = cross_layer(lang_feats, lang_attention_mask,
                                                 visn_feats, visn_attention_mask)

        return lang_feats, visn_feats

In [19]:
lxrt_encoder = LXRTEncoder(bert_config)  # prints the l/x/r layer counts

output = lxrt_encoder(lang_feats=torch.rand(128, 20, 768),
                      lang_attention_mask=torch.rand(128, 1, 1, 20),
                      visn_feats=(torch.rand(128, 36, 2048), torch.rand(128, 36, 4)),
                      visn_attention_mask=None)
lang_out, visn_out = output
lang_out.shape, visn_out.shape

LXRT encoder with 12 l_layers, 5 x_layers, and 0 r_layers.


(torch.Size([128, 20, 768]), torch.Size([128, 36, 768]))

In [20]:
class BertPooler(nn.Module):
    """Pool a sequence into one vector: tanh(Linear(hidden state at position 0))."""
    def __init__(self, config):
        super(BertPooler, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        """(B, L, hidden) -> (B, hidden), using only the first position."""
        cls_vector = hidden_states[:, 0]
        return self.activation(self.dense(cls_vector))

## LXRTModel

In [21]:
class LXRTModel(BertPreTrainedModel):
    """LXRT model: BERT word embeddings + LXRT encoder + pooler.

    Example::

        model = LXRTModel.from_pretrained("bert-base-uncased")
        vocab = model.embeddings.word_embeddings.num_embeddings
        (lang_feats, visn_feats), pooled_output = model(
            input_ids=torch.randint(0, vocab, (128, 20)),
            token_type_ids=torch.zeros(128, 20, dtype=torch.long),
            attention_mask=torch.ones(128, 20, dtype=torch.long),
            visual_feats=(torch.rand(128, 36, 2048), torch.rand(128, 36, 4)),
            visual_attention_mask=None)
        lang_feats.shape     # [128, 20, 768]
        visn_feats.shape     # [128, 36, 768]
        pooled_output.shape  # [128, 768]

    Note: ``torch.rand(...).long()`` is all zeros, so it is not a useful way to
    build example ids or masks (an all-zero attention mask masks every token).
    """

    def __init__(self, config):
        super().__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = LXRTEncoder(config)
        self.pooler = BertPooler(config)
        self.apply(self.init_bert_weights)

    def _extended_additive_mask(self, mask):
        """Turn a (B, L) keep-mask (1 = attend, 0 = masked) into an additive (B, 1, 1, L) mask.

        Kept positions become 0.0 and masked ones -10000.0. Added to the raw
        attention scores before the softmax, this effectively removes masked
        positions. The result is cast to the parameter dtype (fp16 compatibility).
        """
        extended = mask.unsqueeze(1).unsqueeze(2)
        extended = extended.to(dtype=next(self.parameters()).dtype)
        return (1.0 - extended) * -10000.0

    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                visual_feats=None, visual_attention_mask=None):
        """Return ``((lang_feats, visn_feats), pooled_output)``.

        ``attention_mask`` defaults to all ones and ``token_type_ids`` to all
        zeros (both shaped like ``input_ids``). ``visual_attention_mask`` is
        optional; when ``None`` no visual masking is applied.
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # 4-D mask (B, 1, 1, L) broadcasts to
        # (B, num_heads, from_seq_length, to_seq_length) inside attention.
        extended_attention_mask = self._extended_additive_mask(attention_mask)

        if visual_attention_mask is not None:
            extended_visual_attention_mask = self._extended_additive_mask(visual_attention_mask)
        else:
            extended_visual_attention_mask = None

        # Word + position + token-type embeddings, (B, L, hidden).
        embedding_output = self.embeddings(input_ids, token_type_ids)

        lang_feats, visn_feats = self.encoder(
            embedding_output,
            extended_attention_mask,
            visn_feats=visual_feats,
            visn_attention_mask=extended_visual_attention_mask)

        pooled_output = self.pooler(lang_feats)
        return (lang_feats, visn_feats), pooled_output

In [23]:
model = LXRTModel.from_pretrained("bert-base-uncased")
# torch.rand(...).long() is all zeros: every id would be 0 and an all-zero
# attention mask would mask every token. Use valid ids and a keep-all mask.
demo_vocab_size = model.embeddings.word_embeddings.num_embeddings
demo_type_vocab_size = model.embeddings.token_type_embeddings.num_embeddings
output = model(input_ids=torch.randint(0, demo_vocab_size, (128, 20)),
               token_type_ids=torch.randint(0, demo_type_vocab_size, (128, 20)),
               attention_mask=torch.ones(128, 20, dtype=torch.long),
               visual_feats=(torch.rand(128, 36, 2048), torch.rand(128, 36, 4)),
               visual_attention_mask=None)
output[0][0].shape, output[0][1].shape, output[1].shape

I0111 03:19:40.789950 140547482499968 modeling.py:760] loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz from cache at /home/u37216/.pytorch_pretrained_bert/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba
I0111 03:19:40.847867 140547482499968 modeling.py:768] extracting archive file /home/u37216/.pytorch_pretrained_bert/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba to temp dir /home/u37216/tmp/tmpstli2ble
I0111 03:19:48.006472 140547482499968 modeling.py:775] Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 2,
  "vocab_size": 3052

LXRT encoder with 12 l_layers, 5 x_layers, and 0 r_layers.


(torch.Size([128, 20, 768]),
 torch.Size([128, 36, 768]),
 torch.Size([128, 768]))

## LXRT Feature Extraction

Takes `input_ids`, `token_type_ids`, `attention_mask`, a `(visual features, box features)` tuple, and `visual_attention_mask`.

In [26]:
class VisualBertForLXRFeature(BertPreTrainedModel):
    """LXRT feature extractor returning pooled and/or sequence features.

    ``mode`` selects what ``forward`` returns, where ``feat_seq`` is
    ``(lang_feats, visn_feats)``:

      * exactly ``'x'``                               -> ``pooled_output``
      * contains ``'x'`` and ``'l'`` or ``'r'``        -> ``(feat_seq, pooled_output)``
      * contains ``'l'`` or ``'r'`` but no ``'x'``     -> ``feat_seq``

    Example::

        bert = VisualBertForLXRFeature.from_pretrained("bert-base-uncased", mode='x')
        vocab = bert.bert.embeddings.word_embeddings.num_embeddings
        output = bert(input_ids=torch.randint(0, vocab, (128, 20)),
                      token_type_ids=torch.zeros(128, 20, dtype=torch.long),
                      attention_mask=torch.ones(128, 20, dtype=torch.long),
                      visual_feats=(torch.rand(128, 36, 2048), torch.rand(128, 36, 4)),
                      visual_attention_mask=None)
        output.shape  # [128, 768]
    """
    def __init__(self, config, mode='lxr'):
        """
        :param config: BERT configuration passed on to ``LXRTModel``.
        :param mode: output selector string (see class docstring); default 'lxr'.
        """
        super().__init__(config)
        self.bert = LXRTModel(config)
        self.mode = mode
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, visual_feats=None,
                visual_attention_mask=None):
        """Run LXRT and return the outputs selected by ``self.mode``.

        Raises:
            ValueError: if ``self.mode`` matches none of the supported patterns.
        """
        feat_seq, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
                                            visual_feats=visual_feats,
                                            visual_attention_mask=visual_attention_mask)
        wants_pooled = 'x' in self.mode
        wants_seq = 'l' in self.mode or 'r' in self.mode
        if self.mode == 'x':
            return pooled_output
        if wants_pooled and wants_seq:
            return feat_seq, pooled_output
        if wants_seq:
            return feat_seq
        # Fail loudly instead of implicitly returning None for unrecognised modes.
        raise ValueError("Unsupported mode %r: expected 'x' or a string containing "
                         "'l'/'r' (optionally with 'x')" % (self.mode,))

In [27]:
bert = VisualBertForLXRFeature.from_pretrained("bert-base-uncased", mode='x')
# Valid token ids and a keep-all mask; torch.rand(...).long() would be all zeros.
feature_vocab_size = bert.bert.embeddings.word_embeddings.num_embeddings
output = bert(input_ids=torch.randint(0, feature_vocab_size, (128, 20)),
              token_type_ids=torch.zeros(128, 20, dtype=torch.long),
              attention_mask=torch.ones(128, 20, dtype=torch.long),
              visual_feats=(torch.rand(128, 36, 2048), torch.rand(128, 36, 4)),
              visual_attention_mask=None)
output.shape

LXRT encoder with 12 l_layers, 5 x_layers, and 0 r_layers.


torch.Size([128, 768])

## LXRT Encoder wrapper (`LXRTEncoder_`)

Tokenizes raw sentences with the BERT tokenizer and feeds them, together with the visual inputs, to `VisualBertForLXRFeature`.

In [28]:
class LXRTEncoder_(nn.Module):
    """Tokenize raw sentences and run them, with visual features, through LXRT.

    Usage::

        lxrt_encoder = LXRTEncoder_(args, max_seq_length=20).cuda()
        feat = torch.rand(128, 36, 2048).cuda()
        pos = torch.rand(128, 36, 4).cuda()
        sent = list(sentences)  # len(sent) == batch size, e.g. 128
        output = lxrt_encoder(sent, (feat, pos))  # [128, 768] with mode='x'

    Args:
        args: passed to ``set_visual_config``; ``args.from_scratch`` re-initialises
            all weights when true.
        max_seq_length: maximum tokenised sentence length.
        mode: output selector forwarded to ``VisualBertForLXRFeature``.
    """
    def __init__(self, args, max_seq_length, mode='x'):
        super().__init__()
        self.max_seq_length = max_seq_length
        # NOTE: presumably copies the layer counts in ``args`` into VISUAL_CONFIG,
        # which LXRTEncoder reads at construction — so call before building the model.
        set_visual_config(args)

        # Using the bert tokenizer
        self.tokenizer = BertTokenizer.from_pretrained(
            "bert-base-uncased",
            do_lower_case=True
        )

        # Build LXRT Model
        self.model = VisualBertForLXRFeature.from_pretrained(
            "bert-base-uncased",
            mode=mode
        )

        if args.from_scratch:
            print("initializing all the weights")
            self.model.apply(self.model.init_bert_weights)

    def multi_gpu(self):
        """Wrap the model in ``nn.DataParallel``."""
        self.model = nn.DataParallel(self.model)

    @property
    def dim(self):
        """Width of the returned features."""
        return 768

    def forward(self, sents, feats, visual_attention_mask=None):
        """Tokenize ``sents`` (list of str) and run LXRT with ``feats`` = (features, boxes)."""
        train_features = convert_sents_to_features(
            sents, self.max_seq_length, self.tokenizer)

        # Create token tensors on the model's own device rather than hard-coding
        # .cuda(), so the encoder also runs on CPU or a non-default GPU.
        device = next(self.model.parameters()).device
        input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long, device=device)
        input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long, device=device)
        segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long, device=device)

        # Each is (batch, max_seq_length).
        return self.model(input_ids, segment_ids, input_mask,
                          visual_feats=feats,
                          visual_attention_mask=visual_attention_mask)

    def save(self, path):
        """Save the wrapped model's state_dict to ``"<path>_LXRT.pth"`` (the file ``load`` reads)."""
        torch.save(self.model.state_dict(), "%s_LXRT.pth" % path)

    def load(self, path):
        """Load ``"<path>_LXRT.pth"`` into the wrapped model, non-strictly.

        A ``module.`` prefix (left by saving an ``nn.DataParallel`` model) is
        stripped from every key. Keys present only in the file or only in the
        model are printed, then the matching weights are loaded with
        ``strict=False``.
        """
        print("Load LXMERT pre-trained model from %s" % path)
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
        # load_state_dict then copies into the parameters wherever they live.
        # NOTE: torch.load unpickles the file — only load trusted checkpoints.
        state_dict = torch.load("%s_LXRT.pth" % path, map_location="cpu")
        prefix = "module."
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

        # After multi_gpu() self.model is a DataParallel whose own keys carry the
        # "module." prefix; comparing/loading against it would make every stripped
        # key miss and strict=False would silently load nothing.
        target = self.model.module if isinstance(self.model, nn.DataParallel) else self.model

        # Print out the differences of pre-trained and model weights.
        load_keys = set(state_dict.keys())
        model_keys = set(target.state_dict().keys())
        print()
        print("Weights in loaded but not in model:")
        for key in sorted(load_keys.difference(model_keys)):
            print(key)
        print()
        print("Weights in model but not in loaded:")
        for key in sorted(model_keys.difference(load_keys)):
            print(key)
        print()

        target.load_state_dict(state_dict, strict=False)


In [31]:
# Self-contained inputs: the "Initializing Inputs" cell that defines
# feat/pos/sent appears further down the notebook, so depending on it breaks
# Restart & Run All.
demo_feat = torch.rand(128, 36, 2048).cuda()
demo_pos = torch.rand(128, 36, 4).cuda()
demo_sents = ["Are there any people in this photo?"] * 128

lxrt_encoder = LXRTEncoder_(args, MAX_VQA_LENGTH).cuda()
output = lxrt_encoder(demo_sents, (demo_feat, demo_pos))
output.shape

torch.Size([128, 768])

## VQAModel

In [32]:
class VQAModel(nn.Module):
    """LXRT encoder followed by an MLP head that scores every candidate answer.

    :param num_answers: number of answer classes (width of the output logits).
    :param lxrt_args: layer-count config passed to ``LXRTEncoder_``; when
        omitted, the notebook-level ``args`` is used (previous behavior).
    :param max_seq_length: maximum tokenized question length; when omitted,
        ``MAX_VQA_LENGTH`` is used (previous behavior).
    """
    def __init__(self, num_answers, lxrt_args=None, max_seq_length=None):
        super().__init__()
        # Fall back to the notebook globals so VQAModel(num_answers) is unchanged.
        if lxrt_args is None:
            lxrt_args = args
        if max_seq_length is None:
            max_seq_length = MAX_VQA_LENGTH

        # Build LXRT encoder
        self.lxrt_encoder = LXRTEncoder_(
            lxrt_args,
            max_seq_length=max_seq_length
        )
        hid_dim = self.lxrt_encoder.dim

        # VQA answer head: hid_dim -> 2*hid_dim -> GeLU -> LayerNorm -> num_answers
        self.logit_fc = nn.Sequential(
            nn.Linear(hid_dim, hid_dim * 2),
            GeLU(),
            BertLayerNorm(hid_dim * 2, eps=1e-12),
            nn.Linear(hid_dim * 2, num_answers)
        )
        # Initialise the head with the encoder model's init_bert_weights.
        self.logit_fc.apply(self.lxrt_encoder.model.init_bert_weights)

    def forward(self, feat, pos, sent):
        """
        b -- batch_size, o -- object_number, f -- visual_feature_size
        :param feat: (b, o, f) visual features, e.g. [128, 36, 2048] in this notebook
        :param pos:  (b, o, 4) per-object positions, e.g. [128, 36, 4]
        :param sent: (b,) Type -- list of question strings, e.g. 128 of them
        :return: (b, num_answers) the logit of each answer, e.g. [128, 3129]
        """
        # x: (b, hid_dim) encoder output, e.g. [128, 768] in this notebook
        x = self.lxrt_encoder(sent, (feat, pos))
        logit = self.logit_fc(x)

        return logit

In [33]:
# Number of answer classes scored by the VQA head (last dimension of the logits).
NUM_ANSWERS = 3129
model = VQAModel(NUM_ANSWERS).cuda()

LXRT encoder with 9 l_layers, 5 x_layers, and 5 r_layers.


In [34]:
# Forward pass over the dummy batch: one row of answer logits per question.
# NOTE: feat, pos and sent are created in the "Initializing Inputs" cell
# (In [30]) further down; run that cell before this one.
model(feat.cuda(),pos.cuda(),sent)

tensor([[-0.8473,  1.4354, -1.5965,  ..., -0.6111, -0.7988,  2.1134],
        [-0.7456,  1.3032, -1.2911,  ..., -0.6765, -0.6070,  2.2290],
        [-0.5259,  1.6048, -1.3429,  ..., -0.4054, -0.6350,  2.0440],
        ...,
        [-0.7909,  1.5701, -1.4899,  ..., -0.4096, -0.4184,  2.0882],
        [-0.8287,  1.6658, -1.5937,  ..., -0.4897, -0.7288,  2.0589],
        [-0.4265,  1.4685, -1.0218,  ..., -1.1383, -0.2534,  2.4387]],
       device='cuda:0', grad_fn=<AddmmBackward>)

## Initializing Inputs

In [30]:
# Dummy batch for the smoke tests: one question per row, paired with random
# visual features and positions. Seeded so the random inputs are reproducible.
SEED = 0
torch.manual_seed(SEED)
# The cells above call .cuda() on these tensors; use the GPU when one exists.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_OBJECTS = 36        # o -- objects per image
VISUAL_FEAT_DIM = 2048  # f -- visual feature size

# Questions are kept verbatim (including their original typos).
sent=['Are there any people in this photo?',
 'How many Almira in this bathroom?',
 "Does this animal resemble man's best friend?",
 'What sport is being played?',
 'Is the sink gray?',
 'Why type of flowers are in the vase?',
 'Are they playing at a park?',
 'What is the man doing?',
 'Of what material is the shower curtain?',
 "What company's name is seen?",
 'Where are they all going?',
 'Is it a hot day?',
 'In what type of establishment is this picture most likely taken?',
 'Do these foods appear healthy?',
 'How many lamps are there?',
 'How many sets of giraffes have their necks crossed?',
 'How many legs are there?',
 'What is this type of bike?',
 'What are they wearing on their heads?',
 'How many glass objects are on the windowsill?',
 'What is the person holding in their hand?',
 'What is the biggest thing on the road?',
 'What color is the animal in the sink?',
 'Is he wet mostly from water or mostly from alcohol?',
 'Is the cat sad?',
 "Do the man's cap and shorts match?",
 'Is she wearing a skirt or shorts?',
 'Does everyone have a shirt on?',
 'Have these sheep been sheared?',
 'Are here buildings?',
 'Is this a dairy farm?',
 'What are the people on the scene wearing?',
 'What does the sign say at the bottom?',
 'Is this a collage?',
 'Where is the baby standing?',
 'What time is it?',
 'Who does the sandwich belong to?',
 'Is this picture from a zoo?',
 'What number is on the clock?',
 'Do the power lines go through the Magnolia tree?',
 'How many stickers on the computer?',
 'Is the cat overweight?',
 'What kind of bread is shown?',
 'Is he going into the ocean?',
 'Are the computers on?',
 'What brand is the board?',
 'How many kites are in the air?',
 'What are the cows walking on?',
 'What is the man standing on?',
 'Is this a tall animal?',
 'How many people are cutting this cake?',
 "What colors are on this man's shirt?",
 'Which man is speaking?',
 'How many people are on the beach?',
 'What color is the elephant in the right hand picture?',
 'Would you expect to find this animal at the beach?',
 'What is flying in the sky?',
 'Does near the door need painted?',
 'How many stitch lines are on the inseam of these jeans?',
 'What color is this fruit?',
 'How many slices of pizza are there?',
 "What color is the woman's jacket?",
 'Can you see people in the picture?',
 'How many phones are shown?',
 'What was the name of the company?',
 'What country is this?',
 'Is the graffiti considered artwork?',
 'What is the man on the right doing?',
 'What color is the court?',
 'Does this guys headband match his shirt?',
 'Is he in the air?',
 'Is this type of bird currently endangered?',
 'Has the woman fallen?',
 'What color is the nearest streamer?',
 'What date was the photo taken?',
 'Are there any vegetables on the plate with the sandwich?',
 'What kind of paper is on the floor?',
 "What country's flag can be seen in the background?",
 'How many dogs are looking in front of them?',
 'What type of room is this?',
 'Is this a fun sport?',
 'Does the fire hydrant need to be painted?',
 'How many electronic devices are in the photo?',
 'What is advertised on the back of the truck?',
 'What geographical features are located in the background of the picture?',
 'Which animal is more appropriate for a child?',
 'Is the light on in the refrigerator?',
 'What are group does this female belong in?',
 'What is the man doing?',
 'How old is the person receiving this?',
 'Is the man wearing a hat?',
 'What is in the reflection?',
 'How many animals are there?',
 'What herb is sprinkled on the pizza slice?',
 'How many people are pictured?',
 'Does the horse do this because it wants to?',
 "What is the letter above the people's head?",
 'Is there a car in the background?',
 'What does the white street sign say?',
 'What holiday are they celebrating?',
 'What is on top of table in the vase?',
 'What is the color of the side mirror?',
 'What is the color of the elephant?',
 'Who many donuts have they eaten?',
 'Is everybody in this picture a girl?',
 'Is the cup half full?',
 'What type of wall is in the background?',
 'Which horse is the youngest?',
 'How many pepperonis are on this pizza?',
 'Is the kite multi colored?',
 'Are they celebrating?',
 'What time is this?',
 "What's in the bag?",
 "Is there a face painting on the person's finger in the picture?",
 'How many zebras in the photo?',
 'What is she eating?',
 'What can you see out the window?',
 'Is the person very old?',
 'What is emitting light?',
 'Are the numbers on the clock Roman numerals??',
 'How many people are in the picture?',
 'Is this child feeding the elephant?',
 'How many people are standing?',
 'What is colorful on the ramp?',
 'What is on the plate?',
 'When was the picture taken?',
 'Is the keyboard wireless?',
 'Is the woman smiling?']

# Batch size follows the question list, so the visual inputs always match it.
BATCH_SIZE = len(sent)
# Sample on the CPU generator and then move, so the values do not depend on the device.
feat = torch.rand(BATCH_SIZE, NUM_OBJECTS, VISUAL_FEAT_DIM).to(DEVICE)
pos = torch.rand(BATCH_SIZE, NUM_OBJECTS, 4).to(DEVICE)