In [1]:
import math
import torch
from torch import nn
import torch.nn.functional as F
from transformers.modeling_bert import BertLayerNorm
from transformers import BertTokenizer
from lxrt.modeling import BertPreTrainedModel

I0113 23:23:36.288935 139673391356800 file_utils.py:35] PyTorch version 1.3.1+cpu available.
  from ._conv import register_converters as _register_converters


In [171]:
def _skew(X, pad_value):
    """shift every row 1 step to right"""
    # X = B x M x L
    B, M, L = X.size()
    X = F.pad(X, (0, M + 1), value=pad_value)  # B x M x (L+M+1)
    X = X.view(B, -1)  # B x ML+MM+M
    X = X[:, :-M]  # B x ML+MM
    X = X.view(B, M, M + L)  # B x M x L+M
    return X


def _unskew(X):
    """reverse _skew operation"""
    # X = B x M x L+M
    B, M, L = X.size()
    L -= M
    X = X.view(B, -1)  # B x ML+MM
    X = F.pad(X, (0, M))  # B x ML+MM+M
    X = X.view(B, M, M + L + 1)  # B x M x L+M+1
    X = X[:, :, :L]  # B x M x L
    return X

class AdaptiveMask(nn.Module):
    """Soft masking function for adaptive size.
    It masks out the last K values of an input. The masking value
    goes from 1 to 0 gradually, so K can be learned with
    back-propagation.
    Args:
        max_size: maximum size (i.e. input dimension)
        ramp_size: size of the ramp going from 0 to 1
        init_val: initial size proportion not to be masked out
        shape: learn multiple sizes independent of each other
    """

    def __init__(self, max_size, ramp_size, init_val=0, shape=(1,)):
        nn.Module.__init__(self)
        self._max_size = max_size                                       # 32
        self._ramp_size = ramp_size
        self.current_val = nn.Parameter(torch.zeros(*shape) + init_val) # [12,1,1]
        mask_template = torch.linspace(1 - max_size, 0, steps=max_size) # 32
        self.register_buffer('mask_template', mask_template)

    def forward(self, x):
        mask = self.mask_template + self.current_val * self._max_size # [12,1,32]
        mask = mask / self._ramp_size + 1                             # [12,1,32]
        mask = mask.clamp(0, 1)
        if x.size(-1) < self._max_size:
            # the input could have been trimmed beforehand to save computation
            mask = mask[:, :, -x.size(-1):]
        #print(x.shape, mask.shape)
        x = x * mask # [128, 12, 20, 32], [12, 1, 32]
        #print(x.shape)
        return x

    def get_current_max_size(self, include_ramp=True):
        current_size = math.ceil(self.current_val.max().item() * self._max_size)
        if include_ramp:
            current_size += self._ramp_size
        current_size = max(0, min(self._max_size, current_size))
        return current_size

    def get_current_avg_size(self, include_ramp=True):
        current_size = math.ceil(self.current_val.mean().item() * self._max_size)
        if include_ramp:
            current_size += self._ramp_size
        current_size = max(0, min(self._max_size, current_size))
        return current_size

    def clamp_param(self):
        """this need to be called after each update"""
        self.current_val.data.clamp_(0, 1)


class AdaptiveSpan(nn.Module):
    """Adaptive attention span for Transformerself.
    This module learns an attention span length from data for each
    self-attention head.
    Args:
        attn_span: maximum attention span
        adapt_span_loss: loss coefficient for the span length
        adapt_span_ramp: length of the masking ramp
        adapt_span_init: initial size ratio
        adapt_span_cache: adapt cache size to reduce memory usage
    """
    def __init__(self, attn_span, adapt_span_loss, adapt_span_ramp,
                 adapt_span_init, adapt_span_cache, nb_heads, **kargs):
        nn.Module.__init__(self)
        self._adapt_cache = adapt_span_cache
        self._max_span = attn_span
        self._loss_coeff = adapt_span_loss
        self._nb_heads = nb_heads
        self._mask = AdaptiveMask(max_size=self._max_span,
                                 ramp_size=adapt_span_ramp,
                                 init_val=adapt_span_init,
                                 shape=(128,nb_heads,1, 1)) 
        # TO-DO: batch_size to be dynamically controlled
        
        self.attn_linear = nn.Linear(attn_span+20,attn_span)
        
    def forward(self, attn):
        """mask attention with the right span"""
        # batch and head dimensions are merged together, so separate them first
        B = attn.size(0) # batch size
        M = attn.size(1) # block size
        #attn = attn.reshape(B // self._nb_heads, self._nb_heads, M, -1)
        #################   Project into same embedding space ##################
        attn = self.attn_linear(attn)
        ########################################################################
        attn = self._mask(attn)
        attn = attn / (attn.sum(-1, keepdim=True) + 1e-8)  # normalize so sum is 1

        #attn = attn.view(B, M, -1)
        return attn

    def get_trim_len(self):
        """how much of memory can be trimmed to reduce computation"""
        L = self._max_span
        trim_len = min(L - 1, L - self._mask.get_current_max_size())
        # too fine granularity might be bad for the memory management
        trim_len = math.floor(trim_len / 64) * 64
        return trim_len

    def trim_memory(self, query, key, value, key_pe):
        """trim out unnecessary memory beforehand to reduce computation"""
        trim_len = self.get_trim_len()
        cache_size = key.size(1) - query.size(1)
        trim_len_cache = trim_len - (self._max_span - cache_size)
        if trim_len_cache > 0:
            key = key[:, trim_len_cache:, :]
            value = value[:, trim_len_cache:, :]
        elif trim_len_cache < 0:
            # cache is too short! this happens when validation resumes
            # after a lot of updates.
            key = F.pad(key, [0, 0, -trim_len_cache, 0])
            value = F.pad(value, [0, 0, -trim_len_cache, 0])
        if trim_len > 0:
            if key_pe is not None:
                key_pe = key_pe[:, :, trim_len:]
        return key, value, key_pe

    def get_cache_size(self):
        """determine how long the cache should be"""
        if self._adapt_cache:
            trim_len = self.get_trim_len()
            # give a buffer of 64 steps since a span might increase
            # in future updates
            return min(self._max_span, self._max_span - trim_len + 64)
        else:
            return self._max_span

    def get_loss(self):
        """a loss term for regularizing the span length"""
        return self._loss_coeff * self._max_span * self._mask.current_val.mean()

    def get_current_max_span(self):
        return self._mask.get_current_max_size()

    def get_current_avg_span(self):
        return self._mask.get_current_avg_size()

    def clamp_param(self):
        self._mask.clamp_param()

In [145]:
adapt_span_params = {'adapt_span_enabled': True, 'attn_span': 32, 'adapt_span_loss': 0, 'adapt_span_ramp': 32, 'adapt_span_init': 0,
                     'adapt_span_cache': False, 'nb_heads': 12}

In [146]:
adaptive_span = AdaptiveSpan(**adapt_span_params)
query, key, value = torch.rand(128,20,768), torch.rand(128,36,768), torch.rand(128,36,768)
key_layer, value_layer, key_pe = adaptive_span.trim_memory(query,key,value,nn.Parameter(
            torch.randn(1, 768 // 12, 32)))
key_layer.shape, value_layer.shape, key_pe.shape
att = adaptive_span(torch.rand(128,12,20,52))

torch.Size([128, 12, 20, 32]) torch.Size([128, 12, 1, 32])
torch.Size([128, 12, 20, 32])


In [3]:
from lxrt.entry import InputFeatures,convert_sents_to_features,set_visual_config
from lxrt.modeling import VISUAL_CONFIG

In [4]:
from transformers import BertConfig
bert_config = BertConfig()

In [5]:
class GeLU(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self,x):
        return F.gelu(x)

In [6]:
class Args():
    def __init__(self,l_layers,x_layers,r_layers):
        self.llayers = l_layers
        self.xlayers = x_layers
        self.rlayers = r_layers
        self.from_scratch=False
args = Args(9,5,5)
MAX_VQA_LENGTH = 20

In [7]:
## BertEmbeddings
class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.
    """
    def __init__(self, config):
        super(BertEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size, padding_idx=0)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size, padding_idx=0)

        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None):
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = words_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
    
from transformers import BertConfig
bert_embeddings = BertEmbeddings(BertConfig())
output = bert_embeddings(input_ids = torch.rand(128,20).long(),
                         token_type_ids = torch.rand(128,20).long())
output.shape

torch.Size([128, 20, 768])

In [211]:
## BertAttention

class BertAttention(nn.Module):
    """
    from transformers import BertConfig
    
    bert_att = BertAttention(BertConfig())
    context_output = bert_att(hidden_states = torch.rand(128,20,768),
                          context = torch.rand(128,36,768),
                          attention_mask = None)
    context_output.shape # [128, 20, 768]

    """
    def __init__(self, config, ctx_dim=None, adapt_span_params=None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.num_attention_heads = config.num_attention_heads # 12
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads) # 768/12
        self.all_head_size = self.num_attention_heads * self.attention_head_size # 12*64

        # visual_dim = 2048
        if ctx_dim is None:
            ctx_dim =config.hidden_size
        self.query = nn.Linear(config.hidden_size, self.all_head_size) # 768x768
        self.key = nn.Linear(ctx_dim, self.all_head_size) # 768x768
        self.value = nn.Linear(ctx_dim, self.all_head_size) # 768x768
        
        attn_span = adapt_span_params['attn_span']
        
        self.key_pe = nn.Parameter(
            torch.randn(1, config.hidden_size // self.num_attention_heads, attn_span))
        
        self.key_pe_lin = nn.Linear(attn_span,attn_span+20)
        self.val_layer_project = nn.Linear(attn_span+20,attn_span)
        
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        
        self.adapt_span_enabled = adapt_span_params['adapt_span_enabled']
        
        if self.adapt_span_enabled:
            self.adaptive_span = AdaptiveSpan(**adapt_span_params)
        

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, context, attention_mask=None):
        
        mixed_query_layer = self.query(hidden_states) # [128,20,768]
        mixed_key_layer = self.key(context) # [128,36,768]
        mixed_value_layer = self.value(context) #[128,36,768]
        
        
        
        if self.adapt_span_enabled:
            mixed_key_layer, mixed_value_layer, key_pe = self.adaptive_span.trim_memory(mixed_query_layer,
                                                                            mixed_key_layer,
                                                                            mixed_value_layer,
                                                                            self.key_pe)
            
        # mixed_key_layer -> [128, 52, 768],
        # mixed_value_layer -> [128, 52, 768]
        # key_pe -> [1, 64, 32])
    
        query_layer = self.transpose_for_scores(mixed_query_layer) # [128, 12, 20, 64]
        key_layer = self.transpose_for_scores(mixed_key_layer) # [128, 12, 52, 64]
        value_layer = self.transpose_for_scores(mixed_value_layer) # [128, 12, 52, 64]
        
        
        # Take the dot product between "query" and "key" to get the raw attention scores.
        # print('key_layer_transpose: ', key_layer.transpose(-1,-2).shape) : [128, 12, 64, 36]
    
        attention_cont = torch.matmul(query_layer, key_layer.transpose(-1, -2)) # [128, 12, 20, 52]
        
        #attention_cont = _unskew(attention_cont)
        
        attention_pos = torch.matmul(query_layer, key_pe) # [128,12,20,32]
        
############     Project into same embedding space ###############################     
        attention_pos = self.key_pe_lin(attention_pos) # [128,12,20,52]
###################################################################################    

        attention_scores = attention_cont + attention_pos
        attention_scores = attention_scores / math.sqrt(self.attention_head_size) # [128, 12, 20, 52]
        # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores) # [128, 12, 20, 52]
        #print(attention_probs.shape)
        if self.adapt_span_enabled:
            attention_probs = self.adaptive_span(attention_probs) # [128,12,640]
        
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        
        attention_probs = self.dropout(attention_probs) # [128, 12, 20, 36]
        
        #attention_probs = _skew(attention_probs,0)
        value_layer = value_layer.permute(0,1,3,2) # [128,12,64,52]
        value_layer = self.val_layer_project(value_layer).permute(0,1,3,2) # [128,12,32,64]
        
        #  Earlier: [128, 12, 20, 32] x [128, 12, 32, 64]
        print(attention_probs.shape, value_layer.shape)
        
        context_layer = torch.matmul(attention_probs, value_layer) # [128, 12, 20, 64]
        
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous() # [128, 20, 12, 64]
        
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) # [128, 20, 768]
        
        context_layer = context_layer.view(*new_context_layer_shape) # [128, 20, 768]
        return context_layer

In [212]:
from transformers import BertConfig
bert_att = BertAttention(BertConfig(),adapt_span_params=adapt_span_params)
context_output = bert_att(hidden_states = torch.rand(128,20,768),
                          context = torch.rand(128,36,768),
                          attention_mask = None)
context_output.shape

torch.Size([128, 12, 20, 32]) torch.Size([128, 12, 32, 64])


torch.Size([128, 20, 768])

In [215]:
## BertAttOutput

class BertAttOutput(nn.Module):
    """
    from transformers import BertConfig
    bert_att_output = BertAttOutput(BertConfig())
    output = bert_att_output(torch.rand(128,20,768),torch.rand(128,20,768))
    output.shape [128,20,768]

    """
    def __init__(self, config):
        super(BertAttOutput, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states

from transformers import BertConfig
bert_att_output = BertAttOutput(BertConfig())
output = bert_att_output(torch.rand(128,20,768),torch.rand(128,20,768))
output.shape

torch.Size([128, 20, 768])

In [218]:
## BertCross Attention
class BertCrossattLayer(nn.Module):
    """
    from transformers import BertConfig
    
    bert_cross_att = BertCrossattLayer(BertConfig())
    output = bert_cross_att(input_tensor = torch.rand(128,20,768), 
                        ctx_tensor = torch.rand(128,36,768), 
                        ctx_att_mask = None)
                        
    output.shape [128,20,768]
    """
    def __init__(self, config,adapt_span_params):
        super().__init__()
        self.att = BertAttention(config,adapt_span_params=adapt_span_params)
        self.output = BertAttOutput(config)

    def forward(self, input_tensor, ctx_tensor, ctx_att_mask=None):
        output = self.att(input_tensor, ctx_tensor, ctx_att_mask) # [128,20,768]
        attention_output = self.output(output, input_tensor)
        return attention_output
    
from transformers import BertConfig
bert_cross_att = BertCrossattLayer(BertConfig(),adapt_span_params=adapt_span_params)
output = bert_cross_att(input_tensor = torch.rand(128,20,768), 
                        ctx_tensor = torch.rand(128,36,768), 
                        ctx_att_mask = None)
output.shape

torch.Size([128, 12, 20, 32]) torch.Size([128, 12, 32, 64])


torch.Size([128, 20, 768])

In [221]:
## BertSelfattLayer

class BertSelfattLayer(nn.Module):
    """
    bert_self_att_layer = BertSelfattLayer(bert_config)
    output = bert_self_att_layer(input_tensor = torch.rand(128,20,768),
                             attention_mask = torch.rand(128,1,1,20))
    output.shape [128, 20, 768]
    """
    def __init__(self, config,adapt_span_params=adapt_span_params):
        super(BertSelfattLayer, self).__init__()
        self.self = BertAttention(config,adapt_span_params=adapt_span_params)
        self.output = BertAttOutput(config)

    def forward(self, input_tensor, attention_mask):
        # Self attention attends to itself, thus keys and querys are the same (input_tensor).
        self_output = self.self(input_tensor, input_tensor, attention_mask)
        attention_output = self.output(self_output, input_tensor)
        return attention_output
    
bert_self_att_layer = BertSelfattLayer(bert_config,adapt_span_params=adapt_span_params)
output = bert_self_att_layer(input_tensor = torch.rand(128,20,768),
                             attention_mask = torch.rand(128,1,1,52))
output.shape

torch.Size([128, 12, 20, 32]) torch.Size([128, 12, 32, 64])


torch.Size([128, 20, 768])