In [1]:
import torch.nn as nn
import torch
import numpy as np
from torch.autograd import Variable
import math
import torch.nn.functional as F

#### 注意力计算公式   
$$A = \mathrm{Softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V$$

In [3]:
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(scale * Q @ K^T) @ V``.

    Args:
        attention_dropout: Dropout probability applied to the attention
            weights after the softmax.
    """

    def __init__(self,
                 attention_dropout=0.0):

        super(ScaledDotProductAttention,self).__init__()

        self.dropout = nn.Dropout(attention_dropout)
        self.softmax = nn.Softmax(dim = -1)


    def forward(self,q,k,v,scale=None,attn_mask = None):
        """Attend from ``q`` over ``k`` and return the weighted sum of ``v``.

        Args:
            q, k, v: Tensors whose last two dimensions are (length, depth);
                ``k`` is transposed over those two dimensions before the
                product with ``q``.
            scale: Optional factor multiplied into the raw scores. ``None``
                means "no scaling"; any other value (including 0) is applied.
            attn_mask: Optional mask broadcastable to the score tensor.
                Non-zero / True entries are excluded from the softmax.
                Bool and integer (e.g. uint8) masks are both accepted.

        Returns:
            The context tensor ``softmax(scores) @ v``.
        """
        # Raw similarity scores Q @ K^T.
        scores = torch.matmul(q,k.transpose(-2,-1))

        # Compare against None explicitly: `if scale:` would skip a scale of
        # 0 and raise on a multi-element tensor.
        if scale is not None:
            scores = scores * scale

        # Masked positions are set to -inf so that the softmax assigns them a
        # weight of zero. The encoder masks padding; the decoder additionally
        # masks future positions. The fill is done out of place so the score
        # tensor produced above is not mutated, and the mask is cast to bool
        # because masked_fill rejects non-boolean masks in current PyTorch.
        # NOTE: a row in which every position is masked yields NaN weights.
        if attn_mask is not None:
            scores = scores.masked_fill(attn_mask.bool(), float("-inf"))

        weights = self.softmax(scores)
        weights = self.dropout(weights)
        context = torch.matmul(weights,v)

        return context

#### 多头注意力机制

In [4]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention with a residual connection and layer norm.

    Conceptually each head has its own Wq, Wk, Wv. Here one large projection
    per role produces a single wide Q, K and V, which are then split into
    ``num_heads`` slices of ``dim_per_head`` features each.

    Args:
        d_modl: Model width (input and output feature size). Must be
            divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``d_modl`` is not divisible by ``num_heads``.
    """

    def __init__(self,
                 d_modl=512,
                 num_heads=8,
                 dropout=0.0):

        super(MultiHeadAttention,self).__init__()

        # Without this check the `view` calls in forward() either fail with
        # an unrelated shape error or silently mix features across positions.
        if d_modl % num_heads != 0:
            raise ValueError(
                "d_modl ({}) must be divisible by num_heads ({})".format(d_modl, num_heads))

        self.dim_per_head = d_modl // num_heads  # feature size of one head
        self.num_heads = num_heads
        self.linear_k = nn.Linear(d_modl, d_modl)
        self.linear_v = nn.Linear(d_modl, d_modl)
        self.linear_q = nn.Linear(d_modl, d_modl)

        self.dot_product_attention = ScaledDotProductAttention(dropout)
        self.linear_final = nn.Linear(d_modl,d_modl)
        self.norm = nn.LayerNorm(d_modl)


    def forward(self, keys, values, queries, attn_mask=None):
        """Run multi-head attention.

        Note the argument order: keys, values, queries.

        Args:
            keys, values, queries: Input tensors; each is projected with its
                own linear layer and reshaped to
                (batch, num_heads, length, dim_per_head).
            attn_mask: Optional mask; a head axis is inserted at dim 1, so it
                is expected to be (batch, len_q, len_k) — non-zero / True
                entries are masked out.

        Returns:
            ``LayerNorm(queries + linear_final(attention_output))``.
        """
        residual = queries
        batch_size = keys.size(0)

        # Project the inputs to K, V and Q.
        keys = self.linear_k(keys)
        values = self.linear_v(values)
        queries = self.linear_q(queries)

        # Split each projection into num_heads heads.
        keys = keys.view(batch_size , -1, self.num_heads, self.dim_per_head).transpose(1,2)
        values = values.view(batch_size, -1, self.num_heads, self.dim_per_head).transpose(1,2)
        queries = queries.view(batch_size, -1, self.num_heads, self.dim_per_head).transpose(1,2)

        if attn_mask is not None:
            # A singleton head axis broadcasts over all heads; there is no
            # need to materialise num_heads copies of the mask.
            attn_mask = attn_mask.unsqueeze(1)

        scale = self.dim_per_head ** -0.5  # 1 / sqrt(d_k)
        context = self.dot_product_attention(queries,keys,values,scale,attn_mask)

        # Concatenate the heads back into one d_modl-wide vector per position.
        context = context.transpose(1,2).contiguous() \
                  .view(batch_size,-1,self.num_heads * self.dim_per_head)

        # The final linear layer mixes the concatenated heads; the result is
        # added to the residual and layer-normalised.
        return self.norm(residual+self.linear_final(context))


#### 位置编码   
$$PE_{(pos,2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$   
$$PE_{(pos,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$

In [5]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings.

    ``pe[pos, 2i]   = sin(pos / 10000^(2i / d_model))``
    ``pe[pos, 2i+1] = cos(pos / 10000^(2i / d_model))``

    The table is stored as the non-trainable buffer ``pe`` with shape
    (1, max_seq_len, d_model).

    Args:
        d_model: Feature size of the embeddings (even or odd).
        max_seq_len: Number of positions to precompute.
        dropout: Dropout probability applied to the sum.
    """

    def __init__(self,
                 d_model,
                 max_seq_len,
                 dropout=0.0):

        super(PositionalEncoding,self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_seq_len,d_model)
        position = torch.arange(0.,max_seq_len).unsqueeze(1)
        # 1 / 10000^(2i / d_model), computed in log space.
        div_term = torch.exp(torch.arange(0.,d_model,2)*-(math.log(10000.0)/d_model))
        angles = position * div_term  # (max_seq_len, ceil(d_model / 2))

        pe[:,0::2] = torch.sin(angles)  # even feature indices
        # For odd d_model there is one fewer odd index than even index, so
        # trim the angle table to the number of cosine columns.
        pe[:,1::2] = torch.cos(angles[:, :d_model // 2])  # odd feature indices

        pe = pe.unsqueeze(0)
        self.register_buffer("pe",pe)


    def forward(self,x):
        """Return ``dropout(x + pe[:, :x.size(1)])``.

        Args:
            x: Tensor whose dim 1 is the sequence axis and whose last dim is
                ``d_model``.

        Raises:
            ValueError: If the sequence is longer than the precomputed table.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                "sequence length {} exceeds max_seq_len {}".format(seq_len, self.pe.size(1)))

        # Buffers never require grad, so no Variable wrapper is needed.
        x = x + self.pe[:,:seq_len]

        return self.dropout(x)


#### 前向+层归一

$$Out = \mathrm{LayerNorm}\left(x + W_2\,\mathrm{ReLU}(W_1 x + b_1) + b_2\right)$$

In [6]:
class PositionalWiseFeedForward(nn.Module):
    """Position-wise feed-forward sub-layer with residual and layer norm.

    Computes ``LayerNorm(x + Dropout(W2 · ReLU(W1 · x)))``.

    Args:
        d_model: Input and output feature size.
        ffn_dim: Width of the hidden layer.
        dropout: Dropout probability applied to the feed-forward output
            before the residual addition.
    """

    def __init__(self, d_model=512, ffn_dim=2048, dropout=0.0):
        super().__init__()

        # Expand to the hidden width, then project back to d_model.
        self.w1 = nn.Linear(d_model, ffn_dim)
        self.w2 = nn.Linear(ffn_dim, d_model)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        """Apply the two-layer MLP, add the input back and normalise."""
        hidden = F.relu(self.w1(x))
        projected = self.dropout(self.w2(hidden))
        return self.norm(x + projected)

In [7]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention followed by the feed-forward sub-layer.

    Args:
        d_model: Model width.
        num_heads: Number of attention heads.
        ffn_dim: Hidden width of the feed-forward sub-layer. The default is
            2048, matching every other block in this file (the previous
            default of 2018 was a typo).
        dropout: Dropout probability passed to both sub-layers.
    """

    def __init__(self,
                 d_model = 512,
                 num_heads = 8,
                 ffn_dim = 2048,
                 dropout = 0.0):

        super(EncoderLayer,self).__init__()

        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = PositionalWiseFeedForward(d_model, ffn_dim, dropout)

    def forward(self, x, attn_mask = None):
        """Self-attend over ``x`` (used as keys, values and queries), then
        apply the feed-forward sub-layer.

        Args:
            x: Input tensor for this layer.
            attn_mask: Optional mask forwarded to the attention module.
        """
        context = self.attention(x,x,x,attn_mask)
        output = self.feed_forward(context)

        return output

In [8]:
class Encoder(nn.Module):
    """Stack of ``num_layers`` encoder layers with positional encoding and a
    final layer norm.

    Args:
        vocab_size: Accepted for interface symmetry; not used in this class
            (the embedding module is supplied to ``forward``).
        max_seq_len: Number of positions precomputed by the positional encoding.
        num_layers: Number of stacked ``EncoderLayer`` modules.
        d_model: Model width.
        num_heads: Attention heads per layer.
        ffn_dim: Hidden width of each feed-forward sub-layer.
        dropout: Dropout probability used throughout.
    """

    def __init__(self, vocab_size, max_seq_len, num_layers=6, d_model=512,
                 num_heads=8, ffn_dim=2048, dropout=0.0):
        super().__init__()

        # Build the stack one layer at a time.
        stack = []
        for _ in range(num_layers):
            stack.append(EncoderLayer(d_model, num_heads, ffn_dim, dropout))
        self.encoder_layers = nn.ModuleList(stack)

        self.pos_embedding = PositionalEncoding(d_model, max_seq_len, dropout)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, seq_embedding):
        """Encode a batch of token ids.

        Args:
            x: Token-id tensor; it is passed to ``seq_embedding`` and to
                ``padding_mask`` (which flags entries equal to 0).
            seq_embedding: Callable mapping token ids to embeddings.

        Returns:
            The layer-normalised output of the last encoder layer.
        """
        # The padding mask depends only on the ids, so build it up front.
        pad_mask = padding_mask(x, x)

        hidden = self.pos_embedding(seq_embedding(x))
        for layer in self.encoder_layers:
            hidden = layer(hidden, pad_mask)

        return self.norm(hidden)

In [9]:
class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, encoder-decoder attention,
    then the feed-forward sub-layer.

    The two attention sub-layers use separate ``MultiHeadAttention`` modules.
    Previously a single module served both, which tied the projection
    weights and the layer norm of self-attention and cross-attention.
    Checkpoints saved from that version do not contain the
    ``context_attention.*`` parameters.

    Args:
        d_model: Model width.
        num_heads: Number of attention heads.
        ffn_dim: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability passed to all sub-layers.
    """

    def __init__(self,
                 d_model,
                 num_heads = 8,
                 ffn_dim = 2048,
                 dropout = 0.0):

        super(DecoderLayer,self).__init__()

        # Self-attention over the decoder inputs.
        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        # Attention from decoder positions over the encoder outputs.
        self.context_attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = PositionalWiseFeedForward(d_model, ffn_dim, dropout)


    def forward(self, dec_inputs, enc_outputs, self_attn_mask = None,context_attn_mask = None):
        """Run one decoder block.

        Args:
            dec_inputs: Decoder-side input tensor.
            enc_outputs: Encoder output, used as keys and values of the
                second attention sub-layer.
            self_attn_mask: Optional mask for the self-attention sub-layer.
            context_attn_mask: Optional mask for the encoder-decoder
                attention sub-layer.
        """
        # MultiHeadAttention takes (keys, values, queries, mask).
        dec_output = self.attention(dec_inputs, dec_inputs, dec_inputs, self_attn_mask)

        dec_output = self.context_attention(enc_outputs, enc_outputs, dec_output, context_attn_mask)

        dec_output = self.feed_forward(dec_output)

        return dec_output

class Decoder(nn.Module):
    """Stack of decoder layers followed by a projection to vocabulary logits.

    Args:
        vocab_size: Size of the output vocabulary.
        max_seq_len: Number of positions precomputed by the positional encoding.
        device: Stored as ``self.device``; masks are created on the device of
            the inputs, so it is kept only for callers that read it.
        num_layers: Number of stacked ``DecoderLayer`` modules.
        d_model: Model width.
        num_heads: Attention heads per layer.
        ffn_dim: Hidden width of each feed-forward sub-layer.
        dropout: Dropout probability used throughout, including the
            positional encoding (as in ``Encoder``).
    """

    def __init__(self,
                vocab_size,
                 max_seq_len,
                 device,
                 num_layers = 6,
                 d_model  = 512,
                 num_heads = 8,
                 ffn_dim = 2048,
                 dropout = 0.0,
                 ):

        super(Decoder,self).__init__()
        self.device = device
        self.num_layers = num_layers

        self.decoder_layers = nn.ModuleList(
            [DecoderLayer(d_model,num_heads,ffn_dim,dropout) for _ in range(num_layers)])

        # Not used by forward(), which receives its embedding module as an
        # argument; kept so existing state dicts still load.
        self.seq_embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.pos_embedding = PositionalEncoding(d_model, max_seq_len, dropout)
        self.linear = nn.Linear(d_model, vocab_size, bias=False)


    def forward(self, inputs, enc_output, seq_embedding, context_attn_mask = None):
        """Decode a batch of token ids against the encoder output.

        Args:
            inputs: Decoder token-id tensor; entries equal to 0 are treated
                as padding by ``padding_mask``.
            enc_output: Encoder output passed to every decoder layer.
            seq_embedding: Callable mapping token ids to embeddings.
            context_attn_mask: Optional mask for encoder-decoder attention.

        Returns:
            Logits produced by the final linear layer (vocab_size features).
        """
        embedding = seq_embedding(inputs)
        # PositionalEncoding already returns `embedding + pe`; adding the
        # embedding again here would double the token signal.
        output = self.pos_embedding(embedding)

        self_attention_padding_mask = padding_mask(inputs, inputs)
        # Build the look-ahead mask on the inputs' device so that both masks
        # can be combined regardless of what `self.device` was set to.
        seq_mask = sequence_mask(inputs).to(inputs.device)
        # A position is masked if it is padding OR lies in the future.
        self_attn_mask = self_attention_padding_mask.bool() | seq_mask.bool()

        for decoder in self.decoder_layers:
            output = decoder(output, enc_output,self_attn_mask,context_attn_mask)

        output = self.linear(output)
        return output


In [10]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer sharing one token embedding.

    Args:
        vocab_size: Vocabulary size for the embedding and output projection.
        max_len: Maximum sequence length for the positional encodings.
        device: Device handle forwarded to the decoder and stored on the model.
        num_layers: Number of layers in both the encoder and the decoder.
        stack_layers: Currently unused.
        d_model: Model width.
        num_heads: Attention heads per layer.
        ffn_dim: Hidden width of the feed-forward sub-layers.
        dropout: Dropout probability.
    """

    def __init__(self,
                 vocab_size,
                 max_len,
                 device,
                 num_layers = 6,
                 stack_layers= 6,
                 d_model = 512,
                 num_heads = 8,
                 ffn_dim = 2048,
                 dropout = 0.2):

        super(Transformer, self).__init__()

        self.device = device

        self.encoder = Encoder(vocab_size, max_len, num_layers, d_model, num_heads, ffn_dim, dropout)
        # Decoder takes `device` as its third positional argument; omitting
        # it shifts every following argument by one position.
        self.decoder = Decoder(vocab_size, max_len, device, num_layers, d_model, num_heads, ffn_dim, dropout)

        self.embedding = nn.Embedding(vocab_size, d_model)
        # Not used by forward() (the decoder applies its own projection);
        # kept so existing state dicts still load.
        self.linear = nn.Linear(d_model, vocab_size, bias = False)
        #self.linear = nn.Softmax(dim = 2)



    def forward(self, src_seq, dec_tgt,dec_in):
        """Encode ``src_seq`` and decode ``dec_tgt`` against it.

        Args:
            src_seq: Source token ids (0 is treated as padding).
            dec_tgt: Token ids fed to the decoder.
            dec_in: Currently unused.

        Returns:
            The decoder output (vocabulary logits).
        """
        # In encoder-decoder attention the keys come from the source and the
        # queries from the decoder, so the mask flags padding in `src_seq`
        # and has one row per decoder position: padding_mask(seq_k, seq_q).
        context_attn_mask_dec = padding_mask(src_seq, dec_tgt)

        en_output = self.encoder(src_seq, self.embedding)

        dec_output = self.decoder(dec_tgt, en_output, self.embedding, context_attn_mask_dec)

        #gen_output = self.generator(image_in, en_output)

        #return dec_output,gen_output

        return dec_output

In [11]:
import torch

In [12]:
def padding_mask(seq_k, seq_q):
    """Build a boolean padding mask for attention.

    Args:
        seq_k: 2-D id tensor for the key side; entries equal to 0 count as
            padding.
        seq_q: Tensor for the query side; only its size along dim 1 is used.

    Returns:
        Bool tensor of shape (batch, len_q, len_k) that is True wherever the
        key position holds the padding id 0.
    """
    num_queries = seq_q.shape[1]
    is_pad = (seq_k == 0)
    # Insert a query axis and broadcast the per-key flags across it.
    return is_pad[:, None, :].expand(-1, num_queries, -1)

In [13]:
def padding_mask(txt_in, img_in):
    """Padding mask over text keys, for text-only or text+image queries.

    This definition replaces any earlier ``padding_mask`` in the notebook,
    so it also has to serve the model classes, which call
    ``padding_mask(ids, ids)`` with two 2-D id tensors and expect a
    (batch, len_q, len_k) mask. The rank of the second argument selects the
    behaviour:

    * 2-D ``img_in`` (token ids): one row per position of ``img_in``.
    * otherwise (e.g. a (batch, n_img, feat) tensor): one row per text
      position plus one per position along dim 1 of ``img_in``.

    Args:
        txt_in: 2-D id tensor; entries equal to 0 count as padding.
        img_in: Second sequence; only its rank and size along dim 1 are used.

    Returns:
        Bool tensor of shape (batch, rows, len_txt), True where ``txt_in``
        is 0.
    """
    len_txt = txt_in.shape[1]
    if img_in.dim() == 2:
        # Plain attention mask: rows follow the query sequence length.
        num_rows = img_in.shape[1]
    else:
        # Joint text+image sequence: rows cover both modalities.
        num_rows = len_txt + img_in.shape[1]
    pad_mask = txt_in.eq(0)
    return pad_mask.unsqueeze(1).expand(-1, num_rows, -1)

In [14]:
# Toy batch: four token-id rows of length 6, right-filled with 0
# (padding_mask treats 0 as the padding id).
inputs = torch.tensor([[1,2,3,0,0,0],
                       [3,4,0,0,0,0],
                       [3,0,0,0,0,0],
                       [4,5,6,7,0,0]])
# Random (4, 5, 10) tensor passed as `img_in` in the cells below; only its
# size along dim 1 (5) matters to padding_mask.
img = torch.randn(4,5,10)

In [15]:
# Shape check for the toy batch: 4 rows of 6 token ids.
inputs.shape

torch.Size([4, 6])

In [17]:
# Boolean mask that is True where `inputs` is 0. Assuming the
# (txt_in, img_in) definition of padding_mask is the one in effect, the
# result has 6 + 5 = 11 rows and 6 columns per batch item.
t = padding_mask(inputs,img)

In [18]:
# 11 x 11 float canvas per batch item (presumably 6 text + 5 image
# positions). Only the first 6 columns receive the padding flags
# (True -> 1.0); the remaining 5 columns stay 0.
p = torch.zeros(4,11,11)
p[:,:,:6] = padding_mask(inputs,img)

In [19]:
# Display the combined mask built in the previous cell.
p

tensor([[[0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0.]],

        [[0., 0., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
         [0., 0., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
         [0., 0., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
         [0., 0., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
         [0., 0., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
         [0., 0., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
         [0., 0., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
        

In [53]:
# Use a dedicated name for this all-zero demo tensor. Binding it to
# `padding_mask` would replace the padding_mask() function that the model
# classes call, and every later forward pass would fail with
# "'Tensor' object is not callable". Creation and display are kept in one
# cell so the output below belongs to this tensor.
empty_mask = torch.zeros(4,6,6)
empty_mask

tensor([[[0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.]]])

In [58]:
def sequence_mask(seq):
    """Build a look-ahead (causal) mask for a batch of sequences.

    Args:
        seq: 2-D tensor of shape (batch, seq_len); only its shape is used.

    Returns:
        uint8 tensor of shape (batch, seq_len, seq_len), created on the
        default device, with 1 strictly above the main diagonal (positions
        to hide) and 0 elsewhere.
    """
    n_batch, length = seq.shape
    all_ones = torch.ones(length, length, dtype=torch.uint8)
    # Keep only the strict upper triangle: entry (i, j) is 1 when j > i.
    upper = all_ones.triu(1)
    return upper[None, :, :].expand(n_batch, -1, -1)

In [59]:
# Show the uint8 look-ahead mask as booleans: True marks entries strictly
# above the diagonal.
sequence_mask(inputs)==1

tensor([[[False,  True,  True,  True,  True,  True],
         [False, False,  True,  True,  True,  True],
         [False, False, False,  True,  True,  True],
         [False, False, False, False,  True,  True],
         [False, False, False, False, False,  True],
         [False, False, False, False, False, False]],

        [[False,  True,  True,  True,  True,  True],
         [False, False,  True,  True,  True,  True],
         [False, False, False,  True,  True,  True],
         [False, False, False, False,  True,  True],
         [False, False, False, False, False,  True],
         [False, False, False, False, False, False]],

        [[False,  True,  True,  True,  True,  True],
         [False, False,  True,  True,  True,  True],
         [False, False, False,  True,  True,  True],
         [False, False, False, False,  True,  True],
         [False, False, False, False, False,  True],
         [False, False, False, False, False, False]],

        [[False,  True,  True,  True,  T

In [60]:
# Integer tensor of shape (2, 3, 4) used in the unsqueeze/expand demo below.
a = torch.tensor([[[3,4,5,4],
                   [1,2,3,4],
                   [5,6,7,8]],
                  [[0,0,0,0],
                   [9,9,9,9],
                   [5,5,5,5]]])

In [61]:
# Check which device the tensor was created on (CPU in the recorded output).
a.device

device(type='cpu')

In [62]:
# Template tensor whose shape (2, 2, 3, 4) is the target of expand_as below.
b = torch.zeros(2,2,3,4)

In [63]:
# Insert a size-1 axis at dim 1 and broadcast it to b's shape: each (3, 4)
# matrix of `a` appears twice along the new axis.
a.unsqueeze(1).expand_as(b)

tensor([[[[3, 4, 5, 4],
          [1, 2, 3, 4],
          [5, 6, 7, 8]],

         [[3, 4, 5, 4],
          [1, 2, 3, 4],
          [5, 6, 7, 8]]],


        [[[0, 0, 0, 0],
          [9, 9, 9, 9],
          [5, 5, 5, 5]],

         [[0, 0, 0, 0],
          [9, 9, 9, 9],
          [5, 5, 5, 5]]]])