In [2]:
import torch.nn as nn
import torch
from torch.autograd import Variable
import math
import numpy as np
import torch.nn.functional as F

#import matplotlib as plt

In [11]:
class Embeddings(nn.Module):
    def __init__(self, dim, vocab_size):
        super(Embeddings,self).__init__() # == super().__init__() 
        self.lut = nn.Embedding(vocab_size,dim)
        self.d_model = dim
    
    def forward(self, x):
        return self.lut(x) * math.sqrt(self.d_model)

In [12]:
attention_size = (1,6,6)
subseq_mask = np.triu(np.ones(attention_size),k=1)
subseq_mask

array([[[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]]])

In [13]:
def subsequent_mask(size):
    attention_size = (1,size,size)
    subseq_mask = np.triu(np.ones(attention_size),k=1)
    return torch.from_numpy(1-subseq_mask)

In [14]:
trid = subsequent_mask(6)
trid,trid.shape

(tensor([[[1., 0., 0., 0., 0., 0.],
          [1., 1., 0., 0., 0., 0.],
          [1., 1., 1., 0., 0., 0.],
          [1., 1., 1., 1., 0., 0.],
          [1., 1., 1., 1., 1., 0.],
          [1., 1., 1., 1., 1., 1.]]], dtype=torch.float64),
 torch.Size([1, 6, 6]))

In [15]:
trid[0].shape

torch.Size([6, 6])

In [16]:
embedding = nn.Embedding(10,3)
type(embedding)


torch.nn.modules.sparse.Embedding

In [17]:
embedding.weight

Parameter containing:
tensor([[ 0.6451, -0.1286, -0.6585],
        [-1.1187, -0.1429,  0.0454],
        [ 0.7862, -0.8243, -2.0834],
        [-0.3729,  0.5265, -0.6919],
        [-0.1460,  0.6052, -1.1296],
        [-0.0972,  0.1188,  0.3157],
        [ 0.0950, -0.7536, -0.9331],
        [-0.6951, -0.9528, -0.5592],
        [-1.2246, -0.5633,  1.7734],
        [ 0.7964,  1.1891,  0.5069]], requires_grad=True)

In [9]:
inputs = torch.LongTensor([[1,2,3,4],[0,6,7,8]])
inputs.shape

torch.Size([2, 4])

In [18]:
emb = embedding(inputs)

In [19]:
emb

tensor([[[-1.1187, -0.1429,  0.0454],
         [ 0.7862, -0.8243, -2.0834],
         [-0.3729,  0.5265, -0.6919],
         [-0.1460,  0.6052, -1.1296]],

        [[ 0.6451, -0.1286, -0.6585],
         [ 0.0950, -0.7536, -0.9331],
         [-0.6951, -0.9528, -0.5592],
         [-1.2246, -0.5633,  1.7734]]], grad_fn=<EmbeddingBackward0>)

In [20]:
emb.shape

torch.Size([2, 4, 3])

In [21]:
x = Variable(torch.LongTensor([[1,2,3,4],[0,6,7,8]]))

In [22]:
x

tensor([[1, 2, 3, 4],
        [0, 6, 7, 8]])

In [23]:
type(x),type(inputs)

(torch.Tensor, torch.Tensor)

In [24]:
x = Variable(torch.LongTensor([[1,2,3,4],[0,6,7,8]]))
d_model =512
vocab_size = 10000
embs = Embeddings(d_model,vocab_size= vocab_size)
inputs_emb = embs(x)
query = key = value = inputs_emb

In [26]:
inputs_emb.shape[-1],inputs_emb.size(),inputs_emb.size()[-1]

(512, torch.Size([2, 4, 512]), 512)

In [27]:
def attention(query, key, value, mask = None, dropout = None):
    d_k = query.size()[-1]
    scores = torch.matmul(query, key.transpose(-2,-1))/math.sqrt(d_k) #key last dim exchange with last 2nd dim
    print(scores.size())
    if mask is not None:
        scores = scores.masked_fill(mask==0, -1e9)
    attn = F.softmax(scores, dim = -1)
    print(attn.size())
    if dropout is not None:
        attn = dropout(attn)
    return torch.matmul(attn, value), attn

In [28]:
query = key = value = inputs_emb

In [29]:
value.size()

torch.Size([2, 4, 512])

In [30]:
values, attn = attention(query,key,value)

torch.Size([2, 4, 4])
torch.Size([2, 4, 4])


In [31]:
values.shape

torch.Size([2, 4, 512])

In [32]:
attn.shape

torch.Size([2, 4, 4])

In [33]:
import copy
def clones(module, n = 1):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])


In [34]:
class MultiHeadAttention(nn.Module):
    def __init__(self, head, embedding_dim, dropout = 0.1):
        super().__init__()
        assert embedding_dim % head == 0
        self.d_k =   // head
        self.head = head
        self.linears = clones(nn.Linear(embedding_dim, embedding_dim), 4)
        self.attn = None
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, query, key, value, mask = None):
        # if mask
        if mask is not None:
            mask = mask.unsqueeze(0) # expand the dim of mask? represent nth head in multihead
        batch_size = query.size()[0]
        query, key, value = [model(x).view(batch_size, -1, self.head,self.d_k).transpose(1,2) for model, x in zip(self.linears, (query, key, value))] # transpose -1 means length of sentence, finally, last two dims are length of sentence and word emb dim
        x, self.attn = attention(query, key, value, mask = mask, dropout = self.dropout)
        x = x.transpose(1,2).contiguous().view(batch_size, -1, self.d_k*self.head)
        return self.linears[-1](x)            
        

In [35]:
head = 8
embedding_dim = 512
dropout = 0.2
#input
query = key = value = inputs_emb
mask = Variable(torch.zeros(8,4,4)) # number of head, matrix dim
mask1 = mask.unsqueeze(0)

In [36]:
linears = clones(nn.Linear(embedding_dim, embedding_dim), 4)
batch_size = query.size()[0]

In [37]:
batch_size, query.shape

(2, torch.Size([2, 4, 512]))

In [38]:
aa = linears[0](query)

In [39]:
bb = aa.view(batch_size, -1, head,embedding_dim//head)

In [40]:
bb.shape

torch.Size([2, 4, 8, 64])

In [41]:
query1, key1, value1 = [model(x).view(batch_size, -1, head,embedding_dim//head).transpose(1,2) for model, x in zip(linears, (query, key, value))] 

In [58]:
query1.shape

torch.Size([2, 8, 4, 64])

In [29]:
query.shape,query1.shape

(torch.Size([2, 4, 512]), torch.Size([2, 8, 4, 64]))

In [30]:
mask1.shape

torch.Size([1, 8, 4, 4])

In [49]:
def attention(query, key, value, mask = None, dropout = None):
    d_k = query.size()[-1]
    #print(query.size())
    scores = torch.matmul(query, key.transpose(-2,-1))/math.sqrt(d_k) #key last dim exchange with last 2nd dim
    #print(scores.size())
    if mask is not None:
        #print(mask.size())
        scores = scores.masked_fill(mask==0, -1e9)
        #print(scores)
    attn = F.softmax(scores, dim = -1)
    #print(attn.size())
    if dropout is not None:
        attn = dropout(attn)
    return torch.matmul(attn, value), attn

In [32]:
query1.shape == key1.shape

True

In [33]:
x, attn = attention(query1, key1, value1, mask = mask1)

torch.Size([2, 8, 4, 64])
torch.Size([2, 8, 4, 4])
torch.Size([1, 8, 4, 4])
tensor([[[[-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09]],

         [[-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09]],

         [[-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09]],

         [[-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
          [-1.0000e+09

In [34]:
len(linears)

4

In [35]:
query.shape

torch.Size([2, 4, 512])

In [36]:

mask.shape

torch.Size([8, 4, 4])

In [37]:
mha = MultiHeadAttention(head, embedding_dim, dropout)

In [38]:
mha_result = mha(query,key,value,mask)

torch.Size([2, 8, 4, 64])
torch.Size([2, 8, 4, 4])
torch.Size([1, 8, 4, 4])
tensor([[[[-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09]],

         [[-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09]],

         [[-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09]],

         [[-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
          [-1.0000e+09

In [39]:
mha_result

tensor([[[ 2.6542,  5.3653,  3.0327,  ...,  1.3045,  6.9633, -1.8286],
         [ 3.8160, -0.9026,  7.7298,  ...,  0.1826,  4.8034, -0.0115],
         [ 3.6228,  4.8580,  5.5297,  ...,  4.4123,  4.6706, -0.6531],
         [ 1.3359,  4.3676,  4.0880,  ...,  1.5854,  3.6474, -1.6719]],

        [[ 1.0590, -4.1020, -0.7624,  ..., -8.0675,  1.7987,  4.2940],
         [-3.5120, -6.5496,  0.5052,  ..., -7.4359, -0.8139,  3.2296],
         [-3.1542, -5.1074,  2.7063,  ..., -8.7210,  0.7330,  3.4519],
         [-0.1865,  0.7082,  0.8235,  ..., -7.9318, -1.3038,  0.4232]]],
       grad_fn=<ViewBackward0>)

In [40]:
mha_result.shape

torch.Size([2, 4, 512])

In [41]:
class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout = 0.1):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff)
        self.w2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
    def forward(self, attn_output):
        return self.w2(self.dropout(F.relu(self.w1(attn_output))))



In [42]:
d_model = 512
d_ff = 128
x = mha_result

In [64]:
ff = PositionWiseFeedForward(d_model,d_ff)
ffn_result = ff(x)

In [65]:
ffn_result.shape

torch.Size([2, 4, 512])

In [59]:
ones = nn.Parameter(torch.ones(1))
ones.shape

torch.Size([1])

In [47]:
class LayerNorm(nn.Module):
    def __init__(self, d_k, esp = 1e-6):
        super().__init__()
        self.d_k = d_k
        self.esp = esp
        self.a1 = nn.Parameter(torch.ones(d_k))
        self.b1 = nn.Parameter(torch.zeros(d_k))
    def forward(self, x):
        x_mean = x.mean(-1,keepdim=True) # word embedding mean
        x_std = x.std(-1,keepdim=True)
        return self.a1*(x-x_mean)/(x_std+self.esp) +self.b1

In [61]:
features = d_model = 512
eps = 1e-6

In [66]:
ln = LayerNorm(d_model,eps)
normalized_r = ln(ffn_result)

In [67]:
normalized_r.shape

torch.Size([2, 4, 512])

In [1]:
"/leet/code".split('/')

['', 'leet', 'code']

In [8]:
class SubLayerConnection(nn.Module):
    def __init__(self, d_model, dropout = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p = dropout)
        self.d_model = d_model
        self.norm = LayerNorm(d_model)
    def forward(self, sublayer, x):
        return x + self.dropout(sublayer(self.norm(x)))
    
    

In [44]:
x = Variable(torch.LongTensor([[1,2,3,4],[0,6,7,8]]))
d_model =512
vocab_size = 10000
embs = Embeddings(d_model,vocab_size= vocab_size)
inputs_emb = embs(x)
query = key = value = inputs_emb

In [45]:
size = d_model = 512
head = 8
dropout = 0.2
mask = Variable(torch.zeros(8,4,4))
self_attn = MultiHeadAttention(head, d_model)

In [48]:
sublayer = lambda x:self_attn(x,x,x, mask)
sc = SubLayerConnection(d_model, dropout)
sc_result = sc(sublayer, inputs_emb)
sc_result.shape

torch.Size([2, 8, 4, 64])
torch.Size([2, 8, 4, 4])
torch.Size([1, 8, 4, 4])
tensor([[[[-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09]],

         [[-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09]],

         [[-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09]],

         [[-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
          [-1.0000e+09

torch.Size([2, 4, 512])