In [1]:
import torch
import torch.nn as nn
import numpy as np

In [None]:
class scaledDotProductAtt(nn.Module):
    """Scaled dot-product attention over batched 3-D tensors.

    Computes ``softmax(q @ k^T / scale) @ v``. ``torch.bmm`` only accepts
    3-D inputs, so every tensor here is (batch, length, features).
    """

    def __init__(self, scale):
        """
        Args:
            scale: divisor applied to the raw q.k scores (``multiHeadAtt``
                passes sqrt(d_k)).
        """
        super().__init__()
        self.scale = scale
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over the ``k``/``v`` pairs.

        Args:
            q: queries, (batch, n_q, d).
            k: keys, (batch, n_k, d).
            v: values, (batch, n_k, d_v).
            mask: optional boolean tensor broadcastable to (batch, n_q, n_k);
                positions where it is True are excluded from attention.

        Returns:
            Tuple ``(att, output)``: the attention weights (batch, n_q, n_k)
            and the attended values (batch, n_q, d_v).
        """
        scores = torch.bmm(q, k.transpose(1, 2)) / self.scale
        if mask is not None:
            # -inf scores become exactly 0 after the softmax. NOTE: a row whose
            # keys are all masked produces NaN weights; keep one key visible.
            scores = scores.masked_fill(mask, -np.inf)
        att = self.softmax(scores)
        output = torch.bmm(att, v)
        return att, output

class multiHeadAtt(nn.Module):
    """Multi-head attention built on ``scaledDotProductAtt``.

    Queries, keys and values are linearly projected into ``n_head`` heads.
    The heads are attended independently and concatenated. The result is
    then projected to ``d_o`` features.
    """

    def __init__(self, n_head, d_k_, d_v_, d_k, d_v, d_o):
        """
        Args:
            n_head: number of attention heads.
            d_k_: feature size of the incoming queries and keys.
            d_v_: feature size of the incoming values.
            d_k: per-head query/key size after projection.
            d_v: per-head value size after projection.
            d_o: feature size of the final output.
        """
        super().__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v
        # Queries and keys must share the per-head size d_k for the dot
        # product; queries are projected from the same input size as keys.
        self.fc_q = nn.Linear(d_k_, d_k * n_head)
        self.fc_k = nn.Linear(d_k_, d_k * n_head)
        self.fc_v = nn.Linear(d_v_, d_v * n_head)
        self.fc_o = nn.Linear(d_v * n_head, d_o)
        self.attention = scaledDotProductAtt(scale=d_k ** 0.5)

    @staticmethod
    def _split_heads(x, n_head, d_head):
        """Reshape (batch, n, n_head*d_head) to (n_head*batch, n, d_head).

        The flattening is head-major: row ``h * batch + b`` is head ``h`` of
        sample ``b``.
        """
        batch, n, _ = x.size()
        return (
            x.view(batch, n, n_head, d_head)
            .permute(2, 0, 1, 3)
            .contiguous()
            .view(-1, n, d_head)
        )

    def forward(self, q, k, v, mask=None):
        """Run multi-head attention.

        Args:
            q: queries, (batch, n_q, d_k_).
            k: keys, (batch, n_k, d_k_).
            v: values, (batch, n_k, d_v_).
            mask: optional boolean mask, expected to be (batch, n_q, n_k);
                True entries are excluded from attention.

        Returns:
            Tuple ``(attn, output)``. ``attn`` holds the per-head weights,
            shape (n_head*batch, n_q, n_k). ``output`` is (batch, n_q, d_o).
        """
        n_head, d_k, d_v = self.n_head, self.d_k, self.d_v
        # Only the batch and query length are needed; the input feature sizes
        # must not overwrite the per-head sizes above.
        batch, n_q, _ = q.size()

        q = self._split_heads(self.fc_q(q), n_head, d_k)
        k = self._split_heads(self.fc_k(k), n_head, d_k)
        v = self._split_heads(self.fc_v(v), n_head, d_v)

        if mask is not None:
            # Repeating along dim 0 lines the mask up with the head-major
            # order used by _split_heads (index h * batch + b -> mask[b]).
            mask = mask.repeat(n_head, 1, 1)
        attn, output = self.attention(q, k, v, mask)

        # (n_head*batch, n_q, d_v) -> (batch, n_q, n_head*d_v): concat heads.
        output = (
            output.view(n_head, batch, n_q, d_v)
            .permute(1, 2, 0, 3)
            .contiguous()
            .view(batch, n_q, -1)
        )
        output = self.fc_o(output)
        return attn, output

class selfAtt(nn.Module):
    """Self-attention: project ``x`` into queries/keys/values and attend.

    The q/k/v projections are plain weight matrices (no bias). They feed a
    ``multiHeadAtt`` whose input sizes equal its per-head sizes.
    """

    def __init__(self, n_head, d_x, d_k, d_v, d_o):
        """
        Args:
            n_head: number of attention heads.
            d_x: feature size of the input ``x``.
            d_k: query/key size (projection and per-head size).
            d_v: value size (projection and per-head size).
            d_o: feature size of the output.
        """
        # nn.Module.__init__ must run before any nn.Parameter or submodule is
        # assigned, otherwise nn.Module.__setattr__ raises AttributeError.
        super().__init__()
        # Uninitialised storage; filled by init_parameters() below.
        self.wq = nn.Parameter(torch.empty(d_x, d_k))
        self.wk = nn.Parameter(torch.empty(d_x, d_k))
        self.wv = nn.Parameter(torch.empty(d_x, d_v))
        self.mha = multiHeadAtt(
            n_head, d_k_=d_k, d_k=d_k, d_v_=d_v, d_v=d_v, d_o=d_o
        )

        self.init_parameters()

    def init_parameters(self):
        """Re-initialise every parameter uniformly in +-1/sqrt(last dim).

        ``self.parameters()`` also yields the Linear weights and biases inside
        ``self.mha``, so their default initialisation is overwritten too.
        """
        with torch.no_grad():
            for param in self.parameters():
                stdv = 1.0 / np.power(param.size(-1), 0.5)
                param.uniform_(-stdv, stdv)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: input, (batch, n, d_x).
            mask: optional boolean mask passed through to ``multiHeadAtt``
                (expected (batch, n, n); True entries are masked out).

        Returns:
            Tuple ``(att, output)`` as returned by ``multiHeadAtt.forward``.
        """
        q = torch.matmul(x, self.wq)
        k = torch.matmul(x, self.wk)
        v = torch.matmul(x, self.wv)

        att, output = self.mha(q, k, v, mask)

        return att, output
