In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def pad_sequence(sequences, require_padding_mask=False, require_lens=False,
                 batch_first=False):
    """Pad a list of variable-length sequences into a single tensor.

    Args:
        sequences: list of B tensors, each of shape (N_i, D) (any trailing
            dims are allowed, as for ``torch.nn.utils.rnn.pad_sequence``).
        require_padding_mask: if True, also build a bool mask of shape
            (B, N_max) that is True at padded positions.
        require_lens: if True, also return the list of original lengths N_i.
        batch_first: if True the padded tensor is (B, N_max, D), otherwise
            (N_max, B, D).

    Returns:
        (padded, padding_mask, padding_lens), where ``padding_mask`` is None
        unless ``require_padding_mask`` and ``padding_lens`` is None unless
        ``require_lens``.
    """
    padded = nn.utils.rnn.pad_sequence(sequences, batch_first=batch_first)
    seq_lens = [seq.shape[0] for seq in sequences]

    padding_mask = None
    if require_padding_mask:
        # The time axis is dim 1 when batch_first, dim 0 otherwise; using
        # padded.shape[0] unconditionally would give B instead of N_max.
        max_len = padded.shape[1] if batch_first else padded.shape[0]
        lens = torch.tensor(seq_lens, device=padded.device)
        positions = torch.arange(max_len, device=padded.device)
        # (1, N_max) >= (B, 1) -> (B, N_max): True where position is past the end.
        padding_mask = positions.unsqueeze(0) >= lens.unsqueeze(1)

    padding_lens = seq_lens if require_lens else None
    return padded, padding_mask, padding_lens

In [3]:
# Two toy batches of random (length, 256) feature sequences with unequal lengths.
a_len, b_len = torch.rand((80, 256)), torch.rand((100, 256))
c_len, d_len = torch.rand((80, 256)), torch.rand((120, 256))
sequence1, sequence2 = [a_len, b_len], [c_len, d_len]

# Sequence-first padding (batch_first=False): padded is (N_max, B, 256), mask is (B, N_max).
padded_seq1, mask_seq1, _ = pad_sequence(sequences=sequence1, require_padding_mask=True)
print(padded_seq1.shape)
print(mask_seq1.shape)

padded_seq2, mask_seq2, _ = pad_sequence(sequences=sequence2, require_padding_mask=True)
print(padded_seq2.shape)
print(mask_seq2.shape)

torch.Size([100, 2, 256])
torch.Size([2, 100])
torch.Size([120, 2, 256])
torch.Size([2, 120])


In [4]:

class ReAttention(nn.Module):
    """Multi-head attention with re-attention (attention maps mixed across heads).

    After the scaled dot-product softmax, the per-head attention maps are mixed
    across heads by a learnable 1x1 convolution (``reatten_matrix``), normalised
    with ``BatchNorm2d`` (``var_norm``) and multiplied by ``reatten_scale``
    before being applied to the values.

    Inputs use the sequence-first layout produced by ``pad_sequence`` with
    ``batch_first=False``: ``(seq_len, batch, dim)``.

    Args:
        dim: feature size of the query/key/value inputs and of the output.
        num_heads: number of attention heads.
        d_k: per-head size of the query/key projections.
        d_v: per-head size of the value projection.
    """

    def __init__(self, dim, num_heads=8, d_k=256, d_v=256):
        super().__init__()
        self.num_heads = num_heads

        # NOTE: only used as the post-BatchNorm multiplier (reatten_scale); the
        # q.k logits are scaled separately by 1/sqrt(d_k) in Myattention. Set it
        # manually to stay compatible with weights trained under an older value.
        self.scale = 1.0 / (num_heads ** 0.5)

        self.d_k = d_k
        self.d_v = d_v
        self.w_q = nn.Linear(dim, num_heads * d_k, bias=False)
        self.w_k = nn.Linear(dim, num_heads * d_k, bias=False)
        self.w_v = nn.Linear(dim, num_heads * d_v, bias=False)
        # NOTE(review): registered (so it is part of the state_dict) but never
        # used in forward.
        self.layer_norm = nn.LayerNorm(dim, eps=1e-6)

        # Re-attention: a 1x1 conv over the head channel mixes the attention maps.
        self.reatten_matrix = nn.Conv2d(self.num_heads, self.num_heads, 1, 1)
        self.var_norm = nn.BatchNorm2d(self.num_heads)
        self.reatten_scale = self.scale

        self.attn_drop = nn.Dropout(0.1)
        self.proj = nn.Linear(num_heads * d_v, dim)
        # NOTE(review): never applied in forward; p=0.9 looks like a typo for
        # 0.1 -- confirm before wiring it in.
        self.proj_drop = nn.Dropout(0.9)

    def _split_heads(self, x, head_dim):
        """Reshape (N, B, H * head_dim) -> (B, H, N, head_dim)."""
        N, B, _ = x.shape
        return x.reshape(N, B, self.num_heads, head_dim).permute(1, 2, 0, 3)

    def Myattention(self, q, k, v, attn_mask):
        """Scaled dot-product attention followed by re-attention.

        Args:
            q: (B, H, Nt, E) queries.
            k: (B, H, Ns, E) keys.
            v: (B, H, Ns, Ev) values.
            attn_mask: additive float mask broadcastable to (B, H, Nt, Ns) with
                0 for visible keys and -inf for masked ones, or None.

        Returns:
            (output, attn): output is (B, H, Nt, Ev); attn is the (B, H, Nt, Ns)
            re-attention map that was applied to ``v``.
        """
        E = q.shape[-1]
        q = q / math.sqrt(E)
        # (B, H, Nt, E) x (B, H, E, Ns) -> (B, H, Nt, Ns)
        attn = torch.matmul(q, k.transpose(-2, -1))
        if attn_mask is not None:
            attn = attn + attn_mask
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Mix heads with the 1x1 conv, then normalise and rescale.
        attn = self.var_norm(self.reatten_matrix(attn)) * self.reatten_scale
        if attn_mask is not None:
            # The conv bias and BatchNorm shift turn the zero weights on masked
            # keys back into non-zero values; re-zero them so padded keys
            # cannot contribute to the output.
            attn = attn.masked_fill(attn_mask == float("-inf"), 0.0)

        # (B, H, Nt, Ns) x (B, H, Ns, Ev) -> (B, H, Nt, Ev)
        output = torch.matmul(attn, v)
        return output, attn

    def forward(self, q, k, v, key_padding_mask=None):
        """Apply re-attention.

        Args:
            q: (Nq, B, dim) queries.
            k: (Nk, B, dim) keys.
            v: (Nk, B, dim) values (same length as ``k``).
            key_padding_mask: optional bool tensor (B, Nk); True marks padded
                key positions to ignore (as produced by ``pad_sequence``).

        Returns:
            (Nq, B, dim) tensor.
        """
        Nq, B, _ = q.shape
        Nk = k.shape[0]

        attn_mask = None
        if key_padding_mask is not None:
            # (B, Nk) bool -> additive (B, 1, 1, Nk) mask, broadcast over heads
            # and queries inside Myattention.
            attn_mask = torch.zeros(
                (B, 1, 1, Nk), dtype=q.dtype, device=q.device
            ).masked_fill(key_padding_mask.view(B, 1, 1, Nk), float("-inf"))

        newq = self._split_heads(self.w_q(q), self.d_k)
        newk = self._split_heads(self.w_k(k), self.d_k)
        newv = self._split_heads(self.w_v(v), self.d_v)

        attn_output, _ = self.Myattention(newq, newk, newv, attn_mask)

        # (B, H, Nq, d_v) -> (Nq, B, H, d_v) -> (Nq, B, H * d_v)
        attn_output = attn_output.permute(2, 0, 1, 3).reshape(
            Nq, B, self.num_heads * self.d_v
        )
        return self.proj(attn_output)




In [5]:
# Self-attention on the first batch: queries, keys and values are the same padded sequences.
MyReAttention = ReAttention(dim=256)
output = MyReAttention(padded_seq1, padded_seq1, padded_seq1, mask_seq1)
print(output.shape)

# Cross-attention: queries from the second batch attend over the first batch's keys/values.
MyReAttention1 = ReAttention(dim=256)
output1 = MyReAttention1(padded_seq2, padded_seq1, padded_seq1, mask_seq1)
print(output1.shape)

torch.Size([100, 2, 256])
torch.Size([120, 2, 256])
