In [1]:
# Toy example: tokens of two sentences, their 1-based positions, and the
# sentence id each token belongs to.
word = "Hello there . I am fine .".split()
position = list(range(1, len(word) + 1))
sentence = [1] * 3 + [2] * 4

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import time

from mem_transformer import RelPartialLearnableMultiHeadAttn

In [77]:
def einsum_custom(expression, mats, opt_level="O2"):
    """Evaluate ``torch.einsum(expression, ...)`` over ``mats`` with an optional up-cast.

    Args:
        expression: einsum equation string, e.g. ``'ibnd,jbnd->ijbn'``.
        mats: sequence of operand tensors.
        opt_level: when ``"O1"``, every operand is converted with ``.float()``
            before contracting; any other value passes ``mats`` through unchanged.

    Returns:
        The tensor produced by ``torch.einsum``.
    """
    operands = tuple(mat.float() for mat in mats) if opt_level == "O1" else mats
    return torch.einsum(expression, operands)

class LAHMRelPartialLearnableMultiHeadAttn(nn.Module):
    """Transformer-XL-style relative attention with a look-ahead memory (LAHM) update.

    The current segment ``w`` attends causally over ``[mems, w]`` using the
    relative-position scores of Transformer-XL (content term ``AC`` plus
    shifted position term ``BD``). When memory is present (``klen > qlen``),
    the memory positions additionally attend to a window of ``qlen`` keys
    ("look-ahead"). That result is interpolated with the running recurrent
    memory ``recur_mems``, weighted by the accumulated softmax normalisers
    ``denom``.

    NOTE(review): the look-ahead path requires ``r`` to have ``klen + mlen``
    rows (the trailing ``mlen`` rows feed ``r_head_k_ahead``). The TODO in the
    denominator update notes that ``mem_len > qlen`` is not handled yet.

    Args:
        n_head, d_model, d_head: attention geometry.
        dropout: dropout rate on the output projection.
        dropatt: dropout rate on the attention probabilities.
        tgt_len, ext_len: accepted for signature compatibility; unused here.
        mem_len: number of trailing positions kept in the returned recurrent
            memory / denominators; ``None`` keeps all of them.
        pre_lnorm: apply LayerNorm to the inputs before projecting, instead of
            after the residual connection.
    """

    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                 tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False):
        super(LAHMRelPartialLearnableMultiHeadAttn, self).__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.mem_len = mem_len
        self.dropout = dropout

        # Module creation order is kept as-is: it fixes the RNG draws used for
        # weight initialisation and the state_dict key order.
        self.q_net = nn.Linear(d_model, n_head * d_head, bias=False)
        self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False)

        self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)

        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)

        self.layer_norm = nn.LayerNorm(d_model)

        self.scale = 1 / (d_head ** 0.5)

        self.pre_lnorm = pre_lnorm

    def _rel_shift(self, x, zero_triu=False):  # see Transformer-XL paper, appendix B
        """Relative shift of a ``[A x B x ...]`` score tensor (Transformer-XL, appendix B).

        Left-pads one zero column along dim 1, reinterprets the buffer as
        ``[B + 1 x A x ...]``, drops the first row and views the result back as
        ``x``'s shape. With ``zero_triu`` the result is multiplied by
        ``tril(ones, B - A)``, zeroing entries above that diagonal.
        """
        zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),
                               device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=1)

        x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])

        x = x_padded[1:].view_as(x)

        if zero_triu:
            # Mask lives on x's device (a CPU-only mask fails for CUDA inputs).
            ones = torch.ones((x.size(0), x.size(1)), device=x.device)
            x = x * torch.tril(ones, x.size(1) - x.size(0))[:,:,None,None]

        return x

    def _rel_shift_future(self, x, zero_triu=False):  # LAHM rel shift
        """Mirror image of ``_rel_shift`` used for the look-ahead scores.

        Right-pads one zero column along dim 1, reinterprets the buffer as
        ``[B + 1 x A x ...]``, drops the last row and views the result back as
        ``x``'s shape. With ``zero_triu`` the result is multiplied by
        ``triu(ones, B - A)``, zeroing entries below that diagonal.
        """
        zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),
                                device=x.device, dtype=x.dtype)
        x_padded = torch.cat([x, zero_pad], dim=1)

        x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])

        x = x_padded[:-1].view_as(x)

        if zero_triu:
            # Mask lives on x's device (a CPU-only mask fails for CUDA inputs).
            ones = torch.ones((x.size(0), x.size(1)), device=x.device)
            x = x * torch.triu(ones, x.size(1) - x.size(0))[:,:,None,None]

        return x

    def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None, recur_mems=None, denom=None):
        """Process one segment and update the look-ahead / recurrent memory.

        Args:
            w: current segment, ``[qlen x bsz x d_model]``.
            r: relative position embeddings, ``[rlen x d_model]``. The first
                ``klen`` rows drive the causal scores; the remaining rows drive
                the look-ahead scores and must number ``klen - qlen``.
            r_w_bias, r_r_bias: content / position biases broadcastable to
                ``[qlen x bsz x n_head x d_head]`` (e.g. ``[n_head x d_head]``).
            attn_mask: optional boolean-like mask (True/1 = masked) for the
                causal scores ``[qlen x klen x bsz x n_head]``.
            mems: optional memory ``[mlen x bsz x d_model]`` prepended to ``w``.
            recur_mems: recurrent memory ``[mlen x bsz x n_head x d_head]``;
                only read when memory is present.
            denom: accumulated softmax normalisers ``[mlen x bsz x n_head]``;
                ``None`` is treated as an empty buffer.

        Returns:
            ``(output, output_mem, new_recur_mems, new_denom)``:
            ``output`` is ``[qlen x bsz x d_model]``.
            ``output_mem`` is the updated memory ``[mlen x bsz x d_model]``,
            or an empty tensor when there is no memory.
            ``new_recur_mems`` and ``new_denom`` hold the last ``mem_len``
            positions (all when ``mem_len`` is None), detached from the graph.
        """
        qlen, bsz = w.size(0), w.size(1)

        #### projections over [mems, w]
        cat = w if mems is None else torch.cat([mems, w], 0)
        # LayerNorm is computed once and shared by the q and kv projections.
        net_in = self.layer_norm(cat) if self.pre_lnorm else cat
        w_head_q = self.q_net(net_in)
        w_head_kv = self.kv_net(net_in)
        r_head_k = self.r_net(r)

        w_head_k, w_head_v = torch.chunk(w_head_kv, 2, dim=-1)

        klen = w_head_k.size(0)
        mlen = klen - qlen
        has_lookahead = mlen > 0

        if has_lookahead:
            # Memory queries, and the qlen keys/values ending at index mlen
            # (the first position of w inside [mems, w]). Non-negative bounds
            # are used because the equivalent negative slice
            # [-2*qlen+1 : -qlen+1] ends at 0, i.e. is empty, when qlen == 1.
            win_beg = max(0, klen - 2 * qlen + 1)
            win_end = klen - qlen + 1
            mem_head_q = w_head_q[:mlen]
            mem_head_k = w_head_k[win_beg:win_end]
            mem_head_v = w_head_v[win_beg:win_end]

        w_head_q = w_head_q[-qlen:]

        w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)           # qlen x bsz x n_head x d_head
        w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)           # klen x bsz x n_head x d_head
        w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)           # klen x bsz x n_head x d_head

        r_head_k_causal = r_head_k[:klen].view(klen, self.n_head, self.d_head)  # klen x n_head x d_head
        r_head_k_ahead = r_head_k[klen:]

        #### compute attention score
        rw_head_q = w_head_q + r_w_bias                                         # qlen x bsz x n_head x d_head
        AC = einsum_custom('ibnd,jbnd->ijbn', (rw_head_q, w_head_k))            # qlen x klen x bsz x n_head

        rr_head_q = w_head_q + r_r_bias
        BD = einsum_custom('ibnd,jnd->ijbn', (rr_head_q, r_head_k_causal))      # qlen x klen x bsz x n_head
        BD = self._rel_shift(BD)

        # [qlen x klen x bsz x n_head]
        attn_score = AC + BD
        attn_score.mul_(self.scale)

        #### compute attention probability
        if attn_mask is not None and attn_mask.any().item():
            if attn_mask.dim() == 2:
                # NOTE(review): broadcast as [1 x A x B x 1]; presumably a
                # [klen x bsz] key mask — confirm with callers.
                attn_score = attn_score.float().masked_fill(
                    attn_mask[None,:,:,None].bool(), -float('inf')).type_as(attn_score)
            elif attn_mask.dim() == 3:
                attn_score = attn_score.float().masked_fill(
                    attn_mask[:,:,:,None].bool(), -float('inf')).type_as(attn_score)

        # [qlen x klen x bsz x n_head]
        attn_prob = F.softmax(attn_score, dim=1)
        attn_prob = self.dropatt(attn_prob)

        #### compute attention vector
        attn_vec = einsum_custom('ijbn,jbnd->ibnd', (attn_prob, w_head_v))     # qlen x bsz x n_head x d_head

        _attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        ##### linear projection
        attn_out = self.drop(self.o_net(_attn_vec))

        if self.pre_lnorm:
            ##### residual connection
            output = w + attn_out
        else:
            ##### residual connection + layer normalization
            output = self.layer_norm(w + attn_out)

        if denom is None:
            # No history yet: behave as if an empty denominator buffer was passed.
            denom = attn_score.new_empty(0)

        #### look-ahead memory
        if has_lookahead:
            mem_head_q = mem_head_q.view(mlen, bsz, self.n_head, self.d_head)  # mlen x bsz x n_head x d_head
            mem_head_k = mem_head_k.view(qlen, bsz, self.n_head, self.d_head)  # qlen x bsz x n_head x d_head
            mem_head_v = mem_head_v.view(qlen, bsz, self.n_head, self.d_head)  # qlen x bsz x n_head x d_head

            AC_mem = einsum_custom('ibnd,jbnd->ijbn', (mem_head_q + r_w_bias, mem_head_k))        # mlen x qlen x bsz x n_head

            r_head_k_ahead = r_head_k_ahead.view(mlen, self.n_head, self.d_head)                   # mlen x n_head x d_head
            BD_mem = einsum_custom('ibnd,jnd->ijbn', (mem_head_q + r_r_bias, r_head_k_ahead))    # mlen x mlen x bsz x n_head

            assert BD_mem.size(0) == BD_mem.size(1)
            BD_mem = self._rel_shift_future(BD_mem)[:, -qlen:].contiguous()                         # mlen x min(mlen, qlen) x bsz x n_head

            attn_score_mem = AC_mem + BD_mem
            attn_score_mem.mul_(self.scale)

            # Mask the strictly lower-triangular part of the mlen x mlen grid
            # (where triu(ones) is 0), keeping only its last qlen columns. Built
            # as a standalone bool tensor so attn_mask may be None here.
            attn_mask_mem = torch.ones(mlen, mlen, device=attn_score_mem.device).tril(diagonal=-1).bool()
            attn_mask_mem = attn_mask_mem[:, -qlen:, None, None]
            attn_score_mem = attn_score_mem.float().masked_fill_(
                attn_mask_mem, -float('inf')).type_as(attn_score_mem)

            attn_prob_mem = F.softmax(attn_score_mem, dim=1)
            attn_prob_mem = self.dropatt(attn_prob_mem)

            attn_vec_mem = einsum_custom('ijbn,jbnd->ibnd', (attn_prob_mem, mem_head_v))          # mlen x bsz x n_head x d_head

            #### memory interpolation
            # Average of the previous recurrent memory and the look-ahead result,
            # weighted by their respective (unnormalised) softmax denominators.
            # NOTE(review): raw exp() without max-subtraction may overflow for
            # large scores, especially in half precision.
            mem_denom = attn_score_mem.exp().sum(dim=1)                                             # mlen x bsz x n_head
            old_weight = denom.unsqueeze(-1)
            new_weight = mem_denom.unsqueeze(-1)
            inter_mems = (recur_mems * old_weight + attn_vec_mem * new_weight) / (old_weight + new_weight)

            #### output projection for the memory
            _inter_mems = inter_mems.contiguous().view(mlen, bsz, self.n_head * self.d_head)
            attn_out_mem = self.drop(self.o_net(_inter_mems))

            if self.pre_lnorm:
                ##### residual connection
                output_mem = mems + attn_out_mem
            else:
                ##### residual connection + layer normalization
                output_mem = self.layer_norm(mems + attn_out_mem)
        else:
            output_mem = torch.empty(0, dtype=attn_score.dtype, device=attn_score.device)
            mem_denom = torch.empty(0, dtype=attn_score.dtype, device=attn_score.device)
            inter_mems = torch.empty(0, dtype=attn_score.dtype, device=attn_score.device)

        #### update denominators
        cur_denom = attn_score.exp().sum(dim=1)                 # qlen x bsz x n_head

        # TODO: should consider mem_len > qlen
        # Empty on the first segment; torch.cat skips 1-D empty tensors.
        prev_denom = denom + mem_denom
        new_denom = torch.cat([prev_denom, cur_denom], dim=0)

        # Keep the most recent mem_len positions (all of them when mem_len is None).
        beg_idx = 0 if self.mem_len is None else max(0, klen - self.mem_len)

        # TODO: Are we going to detach it ?
        new_denom = new_denom[beg_idx:].detach()

        #### update recurrent memory
        new_recur_mems = torch.cat([inter_mems, attn_vec], dim=0)[beg_idx:].detach()

        # tgt output, layer memory output, recurrent memory, recurrent denom
        return output, output_mem, new_recur_mems, new_denom

In [79]:
# LAHM attention benchmark: first segment without memory, then a second
# segment with mlen == qlen memory.
# CUDA kernels launch asynchronously, so the device is synchronised around each
# timed call; otherwise the wall clock mostly measures kernel-launch overhead.
# NOTE: step 1 may also include one-time CUDA/cuBLAS initialisation if the
# device has not been used yet in this kernel.
mem_len = 150
# Prefer the second GPU (original setup) but fall back so the cell still runs
# on single-GPU or CPU-only machines.
if torch.cuda.device_count() > 1:
    device = "cuda:1"
elif torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
on_cuda = torch.device(device).type == "cuda"
mlen = 0
qlen = 150
bsz = 128
dim = 410
d_head = 41
n_head = 10
klen = mlen + qlen

torch.manual_seed(42)

# Move the data first and wrap afterwards: Parameter.to(cuda) returns a plain tensor.
r_w_bias = nn.Parameter(torch.randn(n_head, dim // n_head).to(device))
r_r_bias = nn.Parameter(torch.randn(n_head, dim // n_head).to(device))
dec_attn_mask = torch.triu(
    torch.ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None].to(device)

r = torch.arange(klen - 1, - mlen - 1, -1.0).unsqueeze(1).expand(-1, dim).to(device)

w = torch.randn(qlen, bsz, dim).to(device)

mems = torch.empty(0).to(device)
denom = torch.empty(0).to(device)
recur_mems = torch.empty(0).to(device)

net = LAHMRelPartialLearnableMultiHeadAttn(n_head, dim, d_head, 0.1, tgt_len=qlen, mem_len=mem_len).to(device)
net.eval()

if on_cuda:
    torch.cuda.synchronize(device)
start_time_lahm = time.perf_counter()
hidden, next_layer_mems, recur_mems, recur_denom = net(w, r, r_w_bias, r_r_bias, dec_attn_mask, mems, recur_mems, denom)
if on_cuda:
    torch.cuda.synchronize(device)
elapsed_lahm = time.perf_counter() - start_time_lahm

print("step 1")
print("LAHM: {} s".format(elapsed_lahm))

print("step 2")

# Second segment: the previous qlen positions become memory, and r gains
# qlen extra (look-ahead) rows.
mlen = qlen
klen = 2 * qlen
r = torch.arange(klen - 1, - qlen - 1, -1.0).unsqueeze(1).expand(-1, dim).to(device)
dec_attn_mask = torch.triu(
    torch.ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None].to(device)
mems = torch.randn(qlen, bsz, dim).to(device)

if on_cuda:
    torch.cuda.synchronize(device)
start_time_lahm_2 = time.perf_counter()
hidden, next_layer_mems, recur_mems, recur_denom = net(w, r, r_w_bias, r_r_bias, dec_attn_mask, mems, recur_mems, recur_denom)
if on_cuda:
    torch.cuda.synchronize(device)
elapsed_lahm_2 = time.perf_counter() - start_time_lahm_2

print("LAHM: {} s".format(elapsed_lahm_2))






step 1
LAHM: 0.00693202018737793 s
step 2
torch.Size([300, 128, 410])
torch.Size([150, 128, 410])
LAHM: 0.01898646354675293 s


In [76]:
# Transformer-XL baseline (RelPartialLearnableMultiHeadAttn), same configuration.
# CUDA kernels launch asynchronously, so the device is synchronised around each
# timed call; otherwise the wall clock mostly measures kernel-launch overhead.
# NOTE: step 1 may also include one-time CUDA/cuBLAS initialisation if the
# device has not been used yet in this kernel.
mem_len = 150
# Prefer the second GPU (original setup) but fall back so the cell still runs
# on single-GPU or CPU-only machines.
if torch.cuda.device_count() > 1:
    device = "cuda:1"
elif torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
on_cuda = torch.device(device).type == "cuda"
mlen = 0
qlen = 150
bsz = 128
dim = 410
d_head = 41
n_head = 10
klen = mlen + qlen

torch.manual_seed(42)

# Move the data first and wrap afterwards: Parameter.to(cuda) returns a plain tensor.
r_w_bias = nn.Parameter(torch.randn(n_head, dim // n_head).to(device))
r_r_bias = nn.Parameter(torch.randn(n_head, dim // n_head).to(device))
dec_attn_mask = torch.triu(
    torch.ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None].to(device)

r = torch.arange(klen - 1, - qlen - 1, -1.0).unsqueeze(1).expand(-1, dim).to(device)

w = torch.randn(qlen, bsz, dim).to(device)

# Not used by the XL calls below (mems is reassigned before step 2).
mems = torch.empty(0).to(device)
denom = torch.empty(0).to(device)
recur_mems = torch.empty(0).to(device)

net2 = RelPartialLearnableMultiHeadAttn(n_head, dim, d_head, 0.1, tgt_len=qlen, mem_len=mem_len).to(device)
net2.eval()

if on_cuda:
    torch.cuda.synchronize(device)
start_time_xl = time.perf_counter()
hid = net2(w, r[:klen], r_w_bias, r_r_bias, dec_attn_mask, torch.empty(0).to(device))
if on_cuda:
    torch.cuda.synchronize(device)
elapsed_xl = time.perf_counter() - start_time_xl
print("step 1")
print("xl: {} s".format(elapsed_xl))

# Second segment: the previous qlen positions become memory.
mlen = qlen
klen = 2 * qlen
r = torch.arange(klen - 1, - qlen - 1, -1.0).unsqueeze(1).expand(-1, dim).to(device)
dec_attn_mask = torch.triu(
    torch.ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None].to(device)
mems = torch.randn(qlen, bsz, dim).to(device)

if on_cuda:
    torch.cuda.synchronize(device)
start_time_xl_2 = time.perf_counter()
hid = net2(w, r[:klen], r_w_bias, r_r_bias, dec_attn_mask, mems)
if on_cuda:
    torch.cuda.synchronize(device)
elapsed_xl_2 = time.perf_counter() - start_time_xl_2
print("step 2")
print("xl: {} s".format(elapsed_xl_2))

step 1
xl: 0.0069615840911865234 s
step 2
xl: 0.017897844314575195 s


In [136]:
# Debug inspection of the relative-position tensor r. NOTE(review): the captured
# output below has 8 columns, while every visible cell builds r with dim = 410,
# so it comes from an earlier configuration and is stale.
print(str(r))

tensor([[ 5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.],
        [ 4.,  4.,  4.,  4.,  4.,  4.,  4.,  4.],
        [ 3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.],
        [ 2.,  2.,  2.,  2.,  2.,  2.,  2.,  2.],
        [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [-1., -1., -1., -1., -1., -1., -1., -1.],
        [-2., -2., -2., -2., -2., -2., -2., -2.]])


In [109]:
recur_mems.shape

torch.Size([3, 2, 2, 4])

In [164]:
# Demonstrates that torch.cat accepts (and skips) a 1-D empty tensor alongside
# a 2-D one. The attention layer's forward relies on this when its memory and
# denominator buffers are still torch.empty(0).
a = torch.randn(3, 4)
b = torch.empty(0)
torch.cat((a, b), dim=0)

tensor([[-0.1894, -0.6018, -0.3084,  0.2983],
        [ 0.0254,  1.0685, -0.2839,  0.2252],
        [-0.9691,  0.1838, -0.9397, -0.2222]])

In [112]:
recur_denom.shape

torch.Size([3, 2, 2])