# In [1]:
import torch
import torch.nn as nn
import math

# In [None]:
class TokenEmbedding(nn.Module):
    """Project input feature vectors into the model dimension.

    The projection is a single ``nn.Linear`` layer, so the last axis of the
    input must have size ``in_dim`` and a floating-point dtype.
    NOTE(review): presumably the inputs are one-hot or dense token features
    rather than integer ids -- confirm against the caller.

    Args:
        in_dim: Size of the last axis of the input.
        out_dim: Size of the last axis of the output.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Fixed: the original referenced the non-existent ``nn.Line``.
        self.emb = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        """Return ``x`` projected from ``in_dim`` to ``out_dim`` features."""
        return self.emb(x)


class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional encoding table.

    Even feature columns hold ``sin(pos / 10000 ** (i / model_dims))`` and odd
    columns hold the matching ``cos`` term, where ``i`` is the even column
    index.

    Args:
        max_len: Number of positions to precompute.
        model_dims: Number of feature columns per position.
        device: Device on which the finished table is stored.
    """

    def __init__(self, max_len, model_dims, device):
        super().__init__()
        # Build the table on CPU in float64 for accuracy, then cast and move it
        # once; this also avoids requiring float64 support on ``device``.
        pos = torch.arange(max_len, dtype=torch.float64).unsqueeze(dim=1)
        i = torch.arange(0, model_dims, step=2, dtype=torch.float64)
        angle = pos / (10000.0 ** (i / model_dims))

        emb = torch.zeros((max_len, model_dims), dtype=torch.float64)
        emb[:, 0::2] = torch.sin(angle)
        # Fixed: odd columns use cosine (the original used sine twice). The
        # slice keeps the shapes aligned when ``model_dims`` is odd.
        emb[:, 1::2] = torch.cos(angle[:, : model_dims // 2])

        emb = emb.to(dtype=torch.get_default_dtype(), device=device)
        # A non-persistent buffer follows ``module.to(...)`` without adding an
        # entry to the state dict.
        self.register_buffer("emb", emb, persistent=False)

    def forward(self, x):
        """Return the first ``x.size(1)`` rows of the table.

        Args:
            x: Tensor whose second axis is the sequence length; only its
                shape is used.

        Returns:
            Tensor of shape ``(seq_len, model_dims)``.

        Raises:
            ValueError: If the sequence is longer than the precomputed table.
        """
        seqlen = x.size(1)
        if seqlen > self.emb.size(0):
            raise ValueError(
                f"sequence length {seqlen} exceeds max_len {self.emb.size(0)}"
            )
        return self.emb[:seqlen, :]



class Embedding(nn.Module):
    """Sum of token and positional embeddings followed by dropout.

    Args:
        in_dim: Passed to ``TokenEmbedding`` as its input size.
        max_len: Passed to ``PositionalEmbedding`` as the table length.
        model_dims: Output feature size shared by both embeddings.
        device: Passed to ``PositionalEmbedding``.
        dropprob: Dropout probability applied to the summed embedding.
    """

    def __init__(self, in_dim,  max_len, model_dims, device, dropprob=0.5):
        super().__init__()
        self.token_embedding = TokenEmbedding(in_dim, model_dims)
        self.pos_emb = PositionalEmbedding(max_len, model_dims, device)
        self.drop = nn.Dropout(dropprob)

    def forward(self, x):
        """Embed ``x`` and add the positional term for its sequence length."""
        combined = self.token_embedding(x) + self.pos_emb(x)
        return self.drop(combined)


class LayerNorm(nn.Module):
    """Layer normalization over the last axis of the input.

    The original class was an empty stub whose constructor took no arguments,
    while the layers in this file construct it as ``LayerNorm(model_dim)`` and
    call it on a tensor.

    Args:
        model_dim: Size of the last axis. When given, a learnable per-feature
            scale (``gamma``) and shift (``beta``) are created. When ``None``
            the module only normalizes and has no parameters.
        eps: Value added to the variance for numerical stability.
    """

    def __init__(self, model_dim=None, eps=1e-5):
        super().__init__()
        self.eps = eps
        if model_dim is None:
            self.register_parameter("gamma", None)
            self.register_parameter("beta", None)
        else:
            self.gamma = nn.Parameter(torch.ones(model_dim))
            self.beta = nn.Parameter(torch.zeros(model_dim))

    def forward(self, x):
        """Return ``x`` normalized to zero mean and unit variance per vector.

        Args:
            x: Tensor whose last axis is normalized.

        Returns:
            Tensor with the same shape as ``x``.
        """
        return nn.functional.layer_norm(
            x, (x.size(-1),), self.gamma, self.beta, self.eps
        )

class AttentionLayer(nn.Module):
    """Scaled dot-product attention: ``softmax(q @ k^T / sqrt(d)) @ v``."""

    def __init__(self):
        super().__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, mask=None, eps=1e-16):
        """Attend over ``v`` using the similarity between ``q`` and ``k``.

        Args:
            q: Queries; the last two axes are ``(query_len, d)``.
            k: Keys; the last two axes are ``(key_len, d)``.
            v: Values; the last two axes are ``(key_len, d_v)``.
            mask: Optional tensor broadcastable to the score shape
                ``(..., query_len, key_len)``. Positions where it equals 0
                receive zero attention weight.
            eps: Unused. Kept so existing call sites that pass it still work.

        Returns:
            Tuple ``(output, score)``: the attended values of shape
            ``(..., query_len, d_v)`` and the attention weights.
        """
        d_tensor = k.size(-1)
        score = (q @ k.transpose(-2, -1)) / math.sqrt(d_tensor)

        if mask is not None:
            # Fixed: ``masked_fill`` is out-of-place, so its result must be
            # kept, and the fill must be a very negative number (not ~0) for
            # softmax to drive the masked weights to zero. ``finfo.min`` is
            # used instead of ``-inf`` so a fully masked row stays finite.
            score = score.masked_fill(mask == 0, torch.finfo(score.dtype).min)

        score = self.softmax(score)
        # Fixed: the weights multiply the values from the left.
        return score @ v, score


class FeedForwardLayer(nn.Module):
    """Two-layer position-wise network: linear, GELU, dropout, linear.

    Args:
        model_dim: Size of the input and output feature axis.
        hidden_dim: Size of the intermediate feature axis.
        dropprob: Dropout probability applied after the activation.
    """

    def __init__(self, model_dim, hidden_dim, dropprob=0.0):
        super().__init__()
        self.lin1 = nn.Linear(in_features=model_dim, out_features=hidden_dim)
        self.drop = nn.Dropout(p=dropprob)
        self.act = nn.GELU()
        self.lin2 = nn.Linear(in_features=hidden_dim, out_features=model_dim)

    def forward(self, x):
        """Expand ``x`` to ``hidden_dim``, activate, drop, project back."""
        hidden = self.drop(self.act(self.lin1(x)))
        return self.lin2(hidden)

class MultiHeadAttn(nn.Module):
    """Multi-head attention built on ``AttentionLayer``.

    Queries, keys and values are linearly projected, split into ``num_head``
    heads along the feature axis, attended per head, merged back and passed
    through an output projection.

    Args:
        num_head: Number of attention heads.
        model_dim: Feature size of the inputs and the output.

    Raises:
        ValueError: If ``model_dim`` is not divisible by ``num_head``.
    """

    def __init__(self, num_head, model_dim):
        super().__init__()
        if model_dim % num_head != 0:
            raise ValueError(
                f"model_dim ({model_dim}) must be divisible by "
                f"num_head ({num_head})"
            )
        self.head = num_head
        self.attn = AttentionLayer()
        # Fixed: the original referenced the non-existent ``nn.Linead``.
        self.Q = nn.Linear(model_dim, model_dim)
        self.K = nn.Linear(model_dim, model_dim)
        self.V = nn.Linear(model_dim, model_dim)
        self.out = nn.Linear(model_dim, model_dim)

    def forward(self, q, k, v, mask=None):
        """Return the attended representation for ``q``.

        Args:
            q: Tensor of shape ``(batch, query_len, model_dim)``.
            k: Tensor of shape ``(batch, key_len, model_dim)``.
            v: Tensor of shape ``(batch, key_len, model_dim)``.
            mask: Optional mask forwarded unchanged to the attention layer.

        Returns:
            Tensor of shape ``(batch, query_len, model_dim)``.
        """
        # Fixed: the projections are attributes of ``self``.
        q, k, v = self.Q(q), self.K(k), self.V(v)
        q, k, v = self.split(q), self.split(k), self.split(v)
        # Fixed: the mask was accepted but never passed on.
        o, _ = self.attn(q, k, v, mask=mask)
        o = self.concat(o)
        return self.out(o)

    def split(self, x):
        """Reshape ``(batch, length, model_dim)`` to ``(batch, head, length, d)``."""
        batch_size, length, model_dim = x.size()
        d_dim = model_dim // self.head
        return x.view(batch_size, length, self.head, d_dim).transpose(1, 2)

    def concat(self, x):
        """Inverse of ``split``: merge the heads back into one feature axis."""
        batch_size, head_num, length, d_dim = x.size()
        model_dim = head_num * d_dim
        # Fixed: the original used the undefined names ``batch``/``model_Dim``.
        return x.transpose(1, 2).contiguous().view(batch_size, length, model_dim)


class EncodeLayer(nn.Module):
    """One encoder block: self-attention then feed-forward.

    Each sub-layer is followed by a residual add, layer normalization and
    dropout, in that order.

    Args:
        num_head: Number of attention heads.
        model_dim: Feature size of the input and output.
        hidden_dim: Inner size of the feed-forward sub-layer.
        dropprob: Dropout probability used throughout the block.
    """

    def __init__(self, num_head, model_dim, hidden_dim, dropprob=0.0):
        super().__init__()
        self.Attn = MultiHeadAttn(num_head, model_dim)
        self.norm1 = LayerNorm(model_dim)
        # Fixed: element-wise dropout. ``Dropout2d`` treats a 3-D input as
        # channels and would zero whole slices of the sequence tensor.
        self.drop1 = nn.Dropout(dropprob)

        self.FF = FeedForwardLayer(model_dim, hidden_dim, dropprob)
        self.norm2 = LayerNorm(model_dim)
        self.drop2 = nn.Dropout(dropprob)

    def forward(self, x, mask=None):
        """Apply the block to ``x``.

        Args:
            x: Input tensor; used as query, key and value of the
                self-attention.
            mask: Optional attention mask forwarded to ``MultiHeadAttn``.

        Returns:
            Tensor produced by the second sub-layer (same shape as ``x``
            assuming the sub-layers preserve shape).
        """
        residual = x
        # Fixed: ``mask`` must be passed by keyword after keyword arguments;
        # the original positional form was a syntax error.
        x = self.Attn(q=x, k=x, v=x, mask=mask)
        x = self.norm1(x + residual)
        x = self.drop1(x)

        residual = x
        x = self.FF(x)
        x = self.norm2(x + residual)
        x = self.drop2(x)

        return x


class DecodeLayer(nn.Module):
    """One decoder block: self-attention, cross-attention, feed-forward.

    Each sub-layer is followed by a residual add, layer normalization and
    dropout, in that order.

    The original third sub-layer repeated self-attention through ``Attn2``
    with ``norm2``/``drop2`` and left ``hidden_dim``, ``norm3`` and ``drop3``
    unused. It is now a feed-forward sub-layer (``FF``), mirroring
    ``EncodeLayer``.

    Args:
        num_head: Number of attention heads.
        model_dim: Feature size of the input and output.
        hidden_dim: Inner size of the feed-forward sub-layer.
        dropprob: Dropout probability used throughout the block.
    """

    def __init__(self, num_head, model_dim, hidden_dim, dropprob=0.0):
        super().__init__()
        # Element-wise dropout is used throughout: ``Dropout2d`` treats a 3-D
        # input as channels and would zero whole slices of the tensor.
        self.Attn1 = MultiHeadAttn(num_head, model_dim)
        self.norm1 = LayerNorm(model_dim)
        self.drop1 = nn.Dropout(dropprob)

        self.Attn2 = MultiHeadAttn(num_head, model_dim)
        self.norm2 = LayerNorm(model_dim)
        self.drop2 = nn.Dropout(dropprob)

        self.FF = FeedForwardLayer(model_dim, hidden_dim, dropprob)
        self.norm3 = LayerNorm(model_dim)
        self.drop3 = nn.Dropout(dropprob)

    def forward(self, dec, enc, t_mask, s_t_mask):
        """Apply the block to the decoder stream.

        Args:
            dec: Decoder input tensor.
            enc: Encoder output used as key and value of the cross-attention,
                or ``None`` to skip that sub-layer.
            t_mask: Mask forwarded to the decoder self-attention.
            s_t_mask: Mask forwarded to the cross-attention.

        Returns:
            Tensor produced by the feed-forward sub-layer.
        """
        residual = dec
        # Fixed: masks are passed by keyword (the positional form after
        # keyword arguments was a syntax error) and ``delf`` is now ``self``.
        dec = self.Attn1(q=dec, k=dec, v=dec, mask=t_mask)
        dec = self.norm1(dec + residual)
        dec = self.drop1(dec)

        if enc is not None:
            residual = dec
            # Fixed: cross-attention has its own weights (``Attn2``) instead
            # of reusing the self-attention module.
            dec = self.Attn2(q=dec, k=enc, v=enc, mask=s_t_mask)
            dec = self.norm2(dec + residual)
            dec = self.drop2(dec)

        residual = dec
        dec = self.FF(dec)
        dec = self.norm3(dec + residual)
        dec = self.drop3(dec)

        return dec



class Encoder(nn.Module):
    """Placeholder for the encoder stack; ``forward`` is not implemented yet."""

    def __init__(self):
        # Fixed: ``super(EncodeLayer, self)`` named the wrong class and raised
        # TypeError on instantiation.
        super().__init__()

    def forward(self):
        """Do nothing and return ``None`` (stub)."""
        return None


class Decoder(nn.Module):
    """Placeholder for the decoder stack; ``forward`` is not implemented yet."""

    def __init__(self):
        # Fixed: ``super(DecodeLayer, self)`` named the wrong class and raised
        # TypeError on instantiation.
        super().__init__()

    def forward(self):
        """Do nothing and return ``None`` (stub)."""
        return None



class Transformer(nn.Module):
    """Placeholder for the full model; ``forward`` is not implemented yet."""

    def __init__(self):
        super().__init__()

    def forward(self):
        """Do nothing and return ``None`` (stub)."""
        return None