In [1]:
import math
import torch
from torch import nn
import torch.nn.functional as F

In [2]:
# Settings for the adaptive-span experiment. The whole dict is later unpacked
# into AdaptiveSpan(**adapt_span_params); individual entries are also read
# directly when the mask pieces are built by hand.
adapt_span_params = dict(
    adapt_span_enabled=True,
    attn_span=1024,          # read as the maximum span / mask length
    adapt_span_loss=0,
    adapt_span_ramp=32,      # read as the ramp size
    adapt_span_init=0,       # read as the initial span value
    adapt_span_cache=False,
    nb_heads=12,
    bs=128,
)

In [3]:
# Model dimensions. The per-head and all-heads sizes are derived from the two
# base numbers so they cannot drift apart.
hidden_size = 768
num_attention_heads = 12
attention_head_size = hidden_size // num_attention_heads    # 64
all_head_size = num_attention_heads * attention_head_size   # 768

# Query / key / value projections, each mapping hidden_size -> all_head_size.
# They are created in this order so parameter initialisation draws from the
# RNG in the same sequence as before.
query, key, value = (nn.Linear(hidden_size, all_head_size) for _ in range(3))

In [4]:
# Random stand-ins for real inputs, used only to trace tensor shapes.
# The leading 128 matches adapt_span_params['bs']; the trailing 768 matches
# the input width of the query/key/value projections.
hidden_states = torch.rand((128, 36, 768))    # fed to the query projection
context = torch.rand((128, 20, 768))          # fed to the key projection
attention_mask = torch.rand((128, 1, 1, 20))  # last dim matches context's 20

In [5]:
# Project the inputs into query / key / value space.
# Queries come from `hidden_states`; keys and values both come from `context`.
q = query(hidden_states)
print(q.shape)
k = key(context)
print(k.shape)
# Values go through the dedicated `value` projection. Reusing `key` here would
# make v identical to k and leave the `value` layer unused.
v = value(context)
print(v.shape)

torch.Size([128, 36, 768])
torch.Size([128, 20, 768])
torch.Size([128, 20, 768])


In [6]:
def _skew(X, pad_value):
    """shift every row 1 step to right"""
    # X = B x M x L
    B, M, L = X.size()
    X = F.pad(X, (0, M + 1), value=pad_value)  # B x M x (L+M+1)
    X = X.view(B, -1)  # B x ML+MM+M
    X = X[:, :-M]  # B x ML+MM
    X = X.view(B, M, M + L)  # B x M x L+M
    return X

In [7]:
def _unskew(X):
    """reverse _skew operation"""
    # X = B x M x L+M
    B, M, L = X.size()
    L -= M
    X = X.view(B, -1)  # B x ML+MM
    X = F.pad(X, (0, M))  # B x ML+MM+M
    X = X.view(B, M, M + L + 1)  # B x M x L+M+1
    X = X[:, :, :L]  # B x M x L
    return X

In [11]:
# Hand-built pieces of the adaptive mask, taken from the shared config.
max_size = adapt_span_params['attn_span']    # [attn_span]
ramp_size = adapt_span_params['adapt_span_ramp']
# One span value per (batch element, head): [bs, nb_heads, 1, 1].
# Both sizes come from the config instead of a literal 128, and 'nb_heads'
# (not 'bs') fills the heads dimension.
shape = (adapt_span_params['bs'], adapt_span_params['nb_heads'], 1, 1)
init_val = adapt_span_params['adapt_span_init']
# NOTE(review): nn.init.uniform_ overwrites every element with a draw from
# U(0, 1), so init_val does not affect the final values - confirm intended.
current_val = nn.init.uniform_(nn.Parameter(torch.zeros(*shape) + init_val)) # [bs,nb_heads,1,1]
# Evenly spaced offsets from 1 - max_size up to 0, one per position in the span.
mask_template = torch.linspace(1 - max_size, 0, steps=max_size)

In [25]:
# Build the adaptive-span module from the shared config. AdaptiveSpan is
# defined in a cell outside this excerpt.
adaptive_span = AdaptiveSpan(**adapt_span_params)
attn_span = adapt_span_params['attn_span']
# Positional-embedding parameter: (1, per-head feature size, attn_span).
k_pe = nn.Parameter(torch.randn(1, hidden_size // num_attention_heads, attn_span))

# NOTE: this rebinds k, v and k_pe, so re-running the cell without re-running
# the projection cell feeds already-trimmed tensors back in.
k, v, k_pe = adaptive_span.trim_memory(q, k, v, k_pe)
# A single print call emits one shape per line. Chaining several print()
# calls with commas builds a tuple of Nones that the notebook then echoes as
# the cell result.
print(k.shape, v.shape, k_pe.shape, sep='\n')

torch.Size([128, 1060, 768])
torch.Size([128, 1060, 768])
torch.Size([1, 64, 1024])


(None, None, None)

In [27]:
# Apply transpose_for_scores (defined outside this excerpt) to q, k and v in
# that order. Judging by the printed shapes, it splits the last dimension
# into (heads, head_size) and moves heads ahead of the sequence axis.
q, k, v = map(transpose_for_scores, (q, k, v))
print(q.shape, k.shape, v.shape)

torch.Size([128, 12, 36, 64]) torch.Size([128, 12, 1060, 64]) torch.Size([128, 12, 1060, 64])


In [28]:
# Shape check: swap the last two axes of k so it can be right-multiplied by q.
k.transpose(-2, -1).shape

torch.Size([128, 12, 64, 1060])

In [29]:
# Content-based scores: dot product of every query with every key, batched
# over the leading (batch, head) dimensions.
attention_cont = q @ k.transpose(-2, -1)
print(attention_cont.shape)

torch.Size([128, 12, 36, 1060])


In [30]:
# Remember the 4-D layout so it can be restored after _unskew, which works
# on 3-D tensors.
d0, d1, d2, d3 = attention_cont.shape

In [31]:
# Merge the first two dimensions: _unskew expects a 3-D (B, M, L+M) tensor.
attention_cont = attention_cont.reshape(d0 * d1, d2, d3)

In [32]:
# Shift row i left by i positions and drop the trailing M columns
# (M = number of rows), so the last dimension shrinks from L+M to L.
attention_cont = _unskew(attention_cont)
print(attention_cont.shape)

torch.Size([1536, 36, 1024])


In [33]:
# Split the merged leading dimension back into (d0, d1); the last dimension
# is inferred because _unskew changed its length.
attention_cont = attention_cont.reshape(d0, d1, d2, -1)
print(attention_cont.shape)

torch.Size([128, 12, 36, 1024])


In [34]:
# Position-based scores: every query against the positional embeddings.
# k_pe has a leading dimension of 1, so it broadcasts over batch and heads.
attention_pos = q @ k_pe
attention_pos.shape

torch.Size([128, 12, 36, 1024])

In [35]:
# Combine content and positional scores, then apply the scaled dot-product
# normalisation. The dot products run over the per-head feature dimension
# (attention_head_size), so that is the size to scale by. Scaling by the
# full hidden_size would over-shrink the logits.
attention_scores = attention_cont + attention_pos
attention_scores = attention_scores / math.sqrt(attention_head_size)

In [36]:
# Stretch the mask's last dimension from 20 to 1024 with a freshly
# initialised (untrained) linear layer so it can be added to the scores.
# NOTE(review): this looks like a shape-matching placeholder rather than a
# real mask - confirm. The cell rebinds attention_mask, so it cannot be
# re-run on its own.
attention_mask = nn.Linear(in_features=20, out_features=1024)(attention_mask)

In [37]:
# Add the mask to the scores out of place; the mask's singleton dimensions
# broadcast across the score tensor.
attention_scores = torch.add(attention_scores, attention_mask)

In [38]:
# Normalise the scores over the last dimension, then pass the result through
# the adaptive-span module. The softmax keeps the input shape; the shape
# displayed below this cell is the one after adaptive_span.
attention_probs = F.softmax(attention_scores, dim=-1)
attention_probs = adaptive_span(attention_probs)
attention_probs.shape

torch.Size([128, 12, 36, 1024])