In [1]:
import torch
from torch import nn
import math

In [2]:
# Toy dimensions shared by the demo cells in this notebook.
batch_size = 15    # number of sequences in the sample batch
seq_num = 10       # tokens per sequence
dic_num = 20       # vocabulary size handed to the embedding table
feature_num = 24   # width of every embedding / hidden vector
word_num = 5       # row count of the scratch q/k/v matrices

In [3]:
# Scratch query/key/value matrices of shape (word_num, feature_num).
q = torch.rand(word_num, feature_num, dtype=torch.float32)
k = torch.rand(word_num, feature_num, dtype=torch.float32)
v = torch.rand(word_num, feature_num, dtype=torch.float32)
# Random token ids in [0, dic_num), shape (batch_size, seq_num). The upper
# bound is tied to dic_num so the ids always stay inside the embedding table
# that is built from the same constant.
word = torch.randint(0, dic_num, size=(batch_size, seq_num))

In [4]:
class InputEmbedding(nn.Module):
    """Token-id lookup table whose output is scaled by sqrt(feature_num).

    Args:
        word_num: number of rows in the embedding table (vocabulary size).
        feature_num: width of each embedding vector.
    """

    def __init__(self, word_num, feature_num):
        super().__init__()
        self.w = word_num
        self.d = feature_num
        self.embedding = nn.Embedding(word_num, feature_num)

    def forward(self, X):
        """Embed integer ids ``X`` and scale the result by ``sqrt(self.d)``.

        Returns a tensor shaped like ``X`` with one trailing dimension of
        size ``self.d`` appended.
        """
        # Use the width stored on the instance; a module-level `feature_num`
        # may be missing or may differ from the width this layer was built with.
        return self.embedding(X) * math.sqrt(self.d)

In [5]:
# Embedding table with one row per dictionary entry.
input_emd = InputEmbedding(word_num=dic_num, feature_num=feature_num)

In [6]:
# Look up the random token ids; `word` is (batch_size, seq_num), so the
# embedding adds a trailing feature dimension.
embedd = input_emd(word)

In [7]:
class PositionEmbedding(nn.Module):
    """Add fixed sinusoidal position encodings to an input tensor.

    The table follows the standard definition

        PE[pos, 2i]   = sin(pos / 10000 ** (2i / feature_num))
        PE[pos, 2i+1] = cos(pos / 10000 ** (2i / feature_num))

    Args:
        word_num: number of positions (rows) to precompute.
        feature_num: width of each position vector.
    """

    def __init__(self, word_num,feature_num):
        super().__init__()
        positions = torch.arange(word_num, dtype=torch.float).unsqueeze(1)
        # One exponent per (sin, cos) pair: 2i / feature_num for i = 0, 1, ...
        exponents = torch.arange(0, feature_num, 2, dtype=torch.float) / feature_num
        angles = positions / torch.pow(10000.0, exponents)

        table = torch.zeros(word_num, feature_num)
        table[:, 0::2] = torch.sin(angles)
        # With an odd feature_num there is one cos column fewer than sin columns.
        table[:, 1::2] = torch.cos(angles[:, : feature_num // 2])
        # Registered as a buffer (under the existing attribute name) so it
        # follows .to()/.cuda(); non-persistent keeps the state_dict free of it.
        self.register_buffer("zeros", table, persistent=False)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.shape[-2]`` positions.

        ``x`` must have at least two dimensions, ``(..., seq, feature_num)``,
        with ``seq`` no larger than the ``word_num`` given at construction.
        """
        return self.zeros[: x.shape[-2]] + x

In [8]:
class revised_PositionEmbedding(nn.Module):
    """Vectorised sinusoidal position encoding followed by dropout.

    The table follows the standard definition

        PE[pos, 2i]   = sin(pos / 10000 ** (2i / feature_num))
        PE[pos, 2i+1] = cos(pos / 10000 ** (2i / feature_num))

    Args:
        word_num: number of positions (rows) to precompute.
        feature_num: width of each position vector.
        p: dropout probability applied to the sum of input and encoding.
    """

    def __init__(self, word_num,feature_num, p=0.2):
        super().__init__()
        tword = torch.arange(0, word_num, dtype=torch.float).unsqueeze(1)
        # One divisor per (sin, cos) pair: 10000 ** (2i / feature_num).
        div_term = torch.pow(
            10000.0, torch.arange(0, feature_num, 2, dtype=torch.float) / feature_num
        )
        angles = tword / div_term

        pe = torch.zeros(word_num, feature_num)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd feature_num there is one cos column fewer than sin columns.
        pe[:, 1::2] = torch.cos(angles[:, : feature_num // 2])
        # A buffer is never trained, moves with the module across devices, and
        # (non-persistent) does not appear in the state_dict.
        self.register_buffer("pe", pe, persistent=False)

        self.dropout = nn.Dropout(p)

    def forward(self, x):
        """Add the encodings for the first ``x.shape[-2]`` positions, then dropout.

        ``x`` must have at least two dimensions, ``(..., seq, feature_num)``,
        with ``seq`` no larger than ``word_num``. The input is not detached,
        so gradients still reach whatever produced ``x``; only the table
        itself is constant.
        """
        x = x + self.pe[: x.shape[-2]]
        return self.dropout(x)

In [9]:
# One table row per sequence position.
position_emd = PositionEmbedding(seq_num, feature_num)

In [10]:
# Add the precomputed position table to the token embeddings.
total_emb = position_emd(embedd)

In [11]:
# Vectorised variant of the position module, using its default dropout rate.
revised_pos_emd = revised_PositionEmbedding(seq_num, feature_num)

In [12]:
# Same token embeddings through the revised module, which also applies dropout.
revised_total_emd = revised_pos_emd(embedd)

In [13]:
class LayerNormalization(nn.Module):
    """Normalise the last dimension, then apply a learnable scalar gain and bias.

    Args:
        eps: constant added to the variance before the square root so the
            division stays finite for constant inputs.
    """

    def __init__(self, eps = 10 ** -6):
        super().__init__()
        # Identity affine transform at initialisation (gain 1, bias 0), so a
        # fresh layer performs plain normalisation instead of a random
        # rescale and shift.
        self.alpha = nn.Parameter(torch.ones(1))
        self.beta = nn.Parameter(torch.zeros(1))
        self.eps = eps

    def forward(self,x):
        """Return ``alpha * (x - mean) / sqrt(var + eps) + beta`` over the last dim.

        The variance is the square of ``torch.std`` with its default
        (Bessel-corrected) estimator.
        """
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.alpha * (x - mean) / torch.sqrt(std ** 2 + self.eps) + self.beta

In [14]:
# Layer norm with the default eps.
layer_norm = LayerNormalization()

In [15]:
# Normalise each position's feature vector (statistics are taken over the last dim).
normed_embed = layer_norm(revised_total_emd)

In [16]:
class FCLayer(nn.Module):
    """Position-wise feed-forward block: Linear -> ReLU -> Dropout -> Linear.

    Args:
        input_dim: size of the last dimension of the input.
        output_dim: size of the last dimension of the output.
        p: dropout probability applied after the ReLU.
        hidden_dim: width of the intermediate layer. Defaults to 48, the
            width this block previously hard-coded.
    """

    def __init__(self, input_dim, output_dim, p =0.2, hidden_dim=48):

        super().__init__()
        self.seq = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(p),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self,x):
        """Apply the feed-forward stack to the last dimension of ``x``."""
        return self.seq(x)

In [17]:
# Feed-forward block that keeps the feature width unchanged.
fc = FCLayer(input_dim=feature_num, output_dim=feature_num)

In [18]:
# Run the normalised embeddings through the feed-forward block.
fc_embed = fc(normed_embed)

In [19]:
# Display the result shape; the recorded run shows torch.Size([15, 10, 24]).
fc_embed.shape

torch.Size([15, 10, 24])

In [20]:
# Copied into `thead_num` in the next cell; presumably the number of attention heads.
h = 3

In [21]:
# Three scratch (feature_num, feature_num) matrices, drawn in q, k, v order.
twq, twk, twv = (
    torch.rand(feature_num, feature_num, dtype=torch.float32) for _ in range(3)
)
# Local aliases for the head count and the feature width.
thead_num, tfeature_num = h, feature_num

In [25]:
class MultiHead(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are linearly projected, split into ``head_num``
    heads of width ``feature_num // head_num``, attended per head, merged
    back and passed through the output projection ``wo``.

    Args:
        head_num: number of attention heads.
        feature_num: model width; must be divisible by ``head_num``.
        p: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``feature_num`` is not divisible by ``head_num``.
    """

    def __init__(self, head_num, feature_num, p=0.2):
        super().__init__()
        if feature_num % head_num != 0:
            # Fail at construction instead of at the first .view() in forward.
            raise ValueError(
                f"feature_num ({feature_num}) must be divisible by head_num ({head_num})"
            )
        self.wq = nn.Linear(feature_num, feature_num)
        self.wk = nn.Linear(feature_num, feature_num)
        self.wv = nn.Linear(feature_num, feature_num)
        # Output projection W_O that mixes the concatenated heads.
        self.wo = nn.Linear(feature_num, feature_num)
        self.do = nn.Dropout(p=p)
        self.head_num = head_num
        self.feature_num = feature_num
        self.dk = feature_num // self.head_num

    @staticmethod
    def attention(q, k, v, dk ,mask, dropout=None):
        """Compute ``softmax(q @ k^T / sqrt(dk)) @ v``.

        Args:
            q, k, v: tensors whose last two dims are (length, dk).
            dk: per-head width used for scaling.
            mask: optional tensor broadcastable to the score matrix;
                positions where ``mask == 0`` are excluded from attention.
            dropout: optional module applied to the attention weights.
        """
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(dk)
        if mask is not None:
            # Out-of-place fill so the unmasked scores tensor is not mutated.
            scores = scores.masked_fill(mask == 0, -1e9)
        weights = scores.softmax(dim=-1)
        if dropout is not None:
            weights = dropout(weights)
        return torch.matmul(weights, v)

    def _split_heads(self, t):
        """Reshape (batch, length, feature_num) to (batch, head_num, length, dk)."""
        return t.view(t.shape[0], t.shape[1], self.head_num, self.dk).transpose(1, 2)

    def forward(self,q,k,v, mask=None):
        """Attend from ``q`` over ``k``/``v``; returns (batch, q_length, feature_num)."""
        q = self._split_heads(self.wq(q))
        k = self._split_heads(self.wk(k))
        v = self._split_heads(self.wv(v))
        context = MultiHead.attention(q, k, v, self.dk, mask, self.do)
        # (batch, heads, length, dk) -> (batch, length, heads * dk)
        context = context.transpose(1, 2).contiguous().view(
            context.shape[0], -1, self.feature_num
        )
        return self.wo(context)

In [26]:
class ResidualConnection(nn.Module):
    """Pre-norm residual wrapper: ``x + dropout(sublayer(norm(x)))``.

    Args:
        p: dropout probability applied to the sublayer output.
    """

    def __init__(self, p=0.2):
        super().__init__()
        self.dp = nn.Dropout(p=p)
        self.norm = LayerNormalization()

    def forward(self,x,sublayer):
        """Apply ``sublayer`` to the normalised input and add the skip path."""
        normed = self.norm(x)
        branch = self.dp(sublayer(normed))
        return x + branch

In [27]:
class EncoderBlock(nn.Module):
    """One encoder layer: self-attention then feed-forward, each in a residual wrapper.

    Args:
        attention_layer: module called as ``attention_layer(q, k, v, mask)``.
        ff_layer: module called with a single tensor.
    """

    def __init__(self, attention_layer, ff_layer):
        super().__init__()
        self.attention_layer = attention_layer
        self.ff_layer = ff_layer
        self.block = nn.ModuleList([ResidualConnection() for _ in range(2)])

    def forward(self, x, mask):
        """Run ``x`` through the attention and feed-forward sublayers in turn."""

        def self_attend(normed):
            # Queries, keys and values all come from the same tensor.
            return self.attention_layer(normed, normed, normed, mask)

        after_attention = self.block[0](x, self_attend)
        return self.block[1](after_attention, self.ff_layer)

In [28]:
class Encoder(nn.Module):
    """Stack of encoder blocks followed by a final layer normalisation.

    Args:
        layer: iterable of blocks, each called as ``block(x, mask)``. Pass an
            ``nn.ModuleList`` so the blocks' parameters are registered.
        n: nominal number of blocks; stored for reference only, the loop in
            ``forward`` iterates over ``layer`` itself.
    """

    def __init__(self, layer, n=6):
        # nn.Module must be initialised before any submodule is assigned;
        # without this call the assignments below raise AttributeError.
        super().__init__()
        self.layer = layer
        self.n = n
        self.norm = LayerNormalization()

    def forward(self, x, mask):
        """Apply every block in order, then normalise the result."""
        for b in self.layer:
            x = b(x, mask)
        return self.norm(x)

In [33]:
class DecoderBlock(nn.Module):
    """One decoder layer: masked self-attention, cross-attention, feed-forward.

    Each sublayer is wrapped in its own ``ResidualConnection``.

    Args:
        self_attention_layer: module called as ``layer(q, k, v, mask)`` on the
            decoder stream.
        cross_attention_layer: module called as ``layer(q, k, v, mask)`` with
            queries from the decoder and keys/values from the encoder output.
        ff_layer: module called with a single tensor.
    """

    def __init__(self, self_attention_layer, cross_attention_layer, ff_layer):
        super().__init__()
        self.self_attention_layer = self_attention_layer
        self.cross_attention_layer = cross_attention_layer
        self.ff_layer = ff_layer
        self.block = nn.ModuleList([ResidualConnection() for i in range(3)])

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        """Run the decoder stream ``x`` through the three sublayers.

        Args:
            x: decoder-side input tensor.
            encoder_output: encoder result used as keys and values for
                cross-attention.
            src_mask: mask applied in cross-attention (over encoder positions).
            tgt_mask: mask applied in self-attention (over decoder positions).
        """
        # Self-attention looks at decoder positions, so it needs the target mask.
        x = self.block[0](x, lambda t: self.self_attention_layer(t, t, t, tgt_mask))
        x = self.block[1](
            x,
            lambda t: self.cross_attention_layer(t, encoder_output, encoder_output, src_mask),
        )
        x = self.block[2](x, self.ff_layer)
        return x

In [30]:
class Decoder(nn.Module):
    """Stack of decoder blocks followed by a final layer normalisation.

    Args:
        layer: iterable of blocks, each called as
            ``block(x, encoder_output, src_mask, tgt_mask)``. Pass an
            ``nn.ModuleList`` so the blocks' parameters are registered.
    """

    def __init__(self, layer):
        # nn.Module must be initialised before any submodule is assigned;
        # without this call the assignments below raise AttributeError.
        super().__init__()
        self.layer = layer
        self.norm = LayerNormalization()

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        """Apply every block in order, then normalise the result."""
        for l in self.layer:
            x = l(x, encoder_output, src_mask, tgt_mask)

        return self.norm(x)

In [31]:
class ProjectionLayer(nn.Module):
    """Linear map from model width to vocabulary size, followed by log-softmax.

    Args:
        feature_num: size of the last dimension of the input.
        dic_num: number of output classes (vocabulary size).
    """

    def __init__(self, feature_num, dic_num):
        # nn.Module must be initialised before a submodule is assigned;
        # without this call the assignment below raises AttributeError.
        super().__init__()
        self.linear = nn.Linear(feature_num, dic_num)

    def forward(self,x):
        """Return log-probabilities over the last dimension."""
        x = self.linear(x)
        return torch.log_softmax(x, dim=-1)

In [1]:
class Transformer(nn.Module):
    """Container that wires embeddings, encoder, decoder and output projection.

    Args:
        src_emb, tgt_emb: token embedding modules for source and target ids.
        src_pos_emb, tgt_pos_emb: position-encoding modules applied after the
            token embeddings.
        encoder: module called as ``encoder(x, src_mask)``.
        decoder: module called as
            ``decoder(x, encoder_output, src_mask, tgt_mask)``.
        projection_layer: module mapping decoder output to vocabulary scores.

    The decoder and projection submodules are stored as ``decoder_stack`` and
    ``projection_layer``. They cannot share a name with the ``decoder`` and
    ``projection`` methods: attribute lookup finds the method first, so a
    call such as ``self.decoder(...)`` inside the method would recurse
    forever.
    """

    def __init__(self, src_emb, tgt_emb, src_pos_emb, tgt_pos_emb, encoder, decoder, projection_layer):
        super().__init__()
        self.src_emb = src_emb
        self.tgt_emb = tgt_emb
        self.src_pos_emb = src_pos_emb
        self.tgt_pos_emb = tgt_pos_emb
        self.encoder = encoder
        self.decoder_stack = decoder
        self.projection_layer = projection_layer

    def encode(self, src, src_mask):
        """Embed the source ids, add positions and run the encoder."""
        src = self.src_emb(src)
        src = self.src_pos_emb(src)
        return self.encoder(src, src_mask)

    def decoder(self, tgt, src, src_mask, tgt_mask):
        """Embed the target ids, add positions and run the decoder.

        ``src`` is forwarded to the decoder as its encoder-output argument,
        i.e. it is expected to be the result of ``encode``.
        """
        tgt = self.tgt_emb(tgt)
        tgt = self.tgt_pos_emb(tgt)
        return self.decoder_stack(tgt, src, src_mask, tgt_mask)

    def projection(self,x):
        """Map decoder output to vocabulary scores."""
        return self.projection_layer(x)

    # Verb-style aliases that match `encode`.
    decode = decoder
    project = projection

NameError: name 'nn' is not defined

In [36]:
def initTransformer(src_voc_num, tgt_voc_num, d_model=24, n=6, head_num=3,
                    src_seq_len=None, tgt_seq_len=None):
    """Build a ``Transformer`` and Xavier-initialise its weight matrices.

    Args:
        src_voc_num: source vocabulary size.
        tgt_voc_num: target vocabulary size (also the projection output size).
        d_model: model width shared by all layers.
        n: number of encoder blocks and of decoder blocks.
        head_num: attention heads per ``MultiHead`` layer.
        src_seq_len: number of positions in the source position table.
            Defaults to ``src_voc_num``, the size this function has always
            used for the source table.
        tgt_seq_len: number of positions in the target position table.
            Defaults to ``tgt_voc_num``.

    Returns:
        The assembled ``Transformer``.
    """
    if src_seq_len is None:
        src_seq_len = src_voc_num
    if tgt_seq_len is None:
        tgt_seq_len = tgt_voc_num

    src_emb = InputEmbedding(src_voc_num, d_model)
    tgt_emb = InputEmbedding(tgt_voc_num, d_model)
    src_pos_emb = revised_PositionEmbedding(src_seq_len, d_model)
    tgt_pos_emb = revised_PositionEmbedding(tgt_seq_len, d_model)
    projection = ProjectionLayer(d_model, tgt_voc_num)

    encoder_blocks = []
    for _ in range(n):
        mh = MultiHead(head_num, d_model)
        fc = FCLayer(d_model, d_model)
        encoder_blocks.append(EncoderBlock(mh, fc))
    encoder = Encoder(nn.ModuleList(encoder_blocks), n)

    decoder_blocks = []
    for _ in range(n):
        self_attn = MultiHead(head_num, d_model)
        cross_attn = MultiHead(head_num, d_model)
        fc = FCLayer(d_model, d_model)
        # Append the decoder block built in this iteration, not a leftover
        # encoder block from the loop above.
        decoder_blocks.append(DecoderBlock(self_attn, cross_attn, fc))
    decoder = Decoder(nn.ModuleList(decoder_blocks))

    transformer = Transformer(src_emb, tgt_emb, src_pos_emb, tgt_pos_emb, encoder, decoder, projection)

    # Xavier init only applies to tensors with 2+ dims; biases and scalar
    # gains keep their constructor values.
    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    return transformer