# In [None]:
import os
from os.path import exists
import torch
import torch.nn as nn
from torch.nn.functional import log_softmax, pad
import math
import copy
import time
from torch.optim.lr_scheduler import LambdaLR
import pandas as pd
from torch.utils.data import DataLoader
import warnings
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import ipdb
# Switch read by show_example(): the example functions are only executed
# when this is truthy and the file is being run as a script.
RUN_EXAMPLES = True

def clones(module, N):
    """Return an ``nn.ModuleList`` containing ``N`` deep copies of ``module``.

    Each copy is produced with ``copy.deepcopy``, so the copies start with
    identical parameter values but do not share storage with ``module`` or
    with one another.
    """
    replicas = nn.ModuleList()
    for _ in range(N):
        replicas.append(copy.deepcopy(module))
    return replicas

# mask -------------------------------
def subsequent_mask(size):
    """Build a boolean mask that hides later positions.

    Args:
        size: side length of the square mask.

    Returns:
        A bool tensor of shape ``(1, size, size)`` whose entry ``[0, i, j]``
        is ``True`` when ``j <= i`` and ``False`` when ``j > i``.
    """
    attn_shape = (1, size, size)
    # torch.triu with diagonal=1 keeps only the entries strictly above the
    # main diagonal, i.e. exactly the "future" positions to be hidden.
    future_positions = torch.triu(torch.ones(attn_shape), diagonal=1).type(
        torch.uint8
    )
    # Invert: True wherever the position is NOT in the future.
    return future_positions == 0


# Attention ----------------------------
def attention(query, key, value, mask=None, dropout=None):
    """Compute scaled dot-product attention.

    Args:
        query, key, value: tensors whose last dimension is the per-head
            feature size; MultiHeadedAttention in this file passes them as
            (batch, heads, length, d_k).
        mask: optional tensor broadcastable to the score matrix; positions
            where ``mask == 0`` are excluded from attention.
        dropout: optional callable applied to the attention weights.

    Returns:
        A tuple ``(weighted_values, attention_weights)``.
    """
    head_dim = query.size(-1)
    # Similarity of every query position with every key position, scaled by
    # sqrt(d_k); the last two dims of the result are (query_len, key_len).
    key_t = key.transpose(-2, -1)
    scores = torch.matmul(query, key_t) / math.sqrt(head_dim)
    if mask is not None:
        # masked_fill returns a new tensor (the in-place variant would be
        # masked_fill_); -1e9 drives the softmax weight of masked slots to ~0.
        scores = scores.masked_fill(mask == 0, -1e9)
    # Normalise over the key axis so each query's weights sum to 1.
    weights = torch.softmax(scores, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    attended = torch.matmul(weights, value)
    return attended, weights


class MultiHeadedAttention(nn.Module):
    """Multi-head attention built on top of ``attention``.

    ``d_model`` features are split evenly across ``h`` heads, so each head
    works on ``d_k = d_model // h`` features.
    """

    def __init__(self, h, d_model, dropout=0.1):
        """Create the projections for ``h`` heads over ``d_model`` features.

        ``d_model`` must be divisible by ``h`` (enforced with an assert).
        ``dropout`` is the drop probability applied to the attention weights.
        """
        super().__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        # Four same-shaped linear maps: the first three project query, key
        # and value; the last one projects the merged heads (see forward).
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        # Attention weights of the most recent forward pass.
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        """Project the inputs, attend per head, then merge the heads.

        Returns a tensor with ``h * d_k`` features in its last dimension and
        stores the attention weights in ``self.attn``.
        """
        if mask is not None:
            # Insert a head axis so one mask broadcasts over all heads.
            mask = mask.unsqueeze(1)
        batch = query.size(0)

        def split_heads(projection, tensor):
            # (batch, length, d_model) -> (batch, length, h, d_k)
            #                          -> (batch, h, length, d_k)
            projected = projection(tensor)
            return projected.view(batch, -1, self.h, self.d_k).transpose(1, 2)

        q_heads = split_heads(self.linears[0], query)
        k_heads = split_heads(self.linears[1], key)
        v_heads = split_heads(self.linears[2], value)

        context, self.attn = attention(
            q_heads, k_heads, v_heads, mask=mask, dropout=self.dropout
        )

        # (batch, h, length, d_k) -> (batch, length, h * d_k); contiguous()
        # is needed because view() cannot operate on the transposed layout.
        merged = context.transpose(1, 2).contiguous()
        merged = merged.view(batch, -1, self.h * self.d_k)
        return self.linears[-1](merged)




class EncoderDecoder(nn.Module):
    """Container that wires an encoder, a decoder, both embeddings and a generator.

    The ``generator`` is stored but not called here; callers apply it to the
    decoder output themselves.
    """

    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        """Store the five sub-components under the same attribute names."""
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

    def forward(self, src, tgt, src_mask, tgt_mask):
        """Encode ``src`` and decode ``tgt`` against that encoding.

        Returns whatever ``self.decoder`` returns.
        """
        return self.decode(self.encode(src, src_mask), src_mask, tgt, tgt_mask)

    def encode(self, src, src_mask):
        """Embed ``src`` and run it through the encoder with ``src_mask``."""
        return self.encoder(self.src_embed(src), src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        """Embed ``tgt`` and run the decoder.

        ``memory`` is the encoder output (``forward`` passes the result of
        ``encode`` here); the decoder receives it together with both masks.
        """
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)

    
class Generator(nn.Module):
    """Linear projection to vocabulary size followed by log-softmax."""

    def __init__(self, d_model, vocab):
        """Create the ``d_model -> vocab`` projection."""
        super(Generator, self).__init__()
        self.proj = nn.Linear(d_model, vocab)

    def forward(self, x):
        """Return log-probabilities over the last dimension of ``proj(x)``."""
        return log_softmax(self.proj(x), dim=-1)


# --------------------------------Encoder-----------------------------------
class Encoder(nn.Module):
    """A stack of ``N`` copies of ``layer`` followed by a final LayerNorm."""

    def __init__(self, layer, N):
        """Clone ``layer`` ``N`` times and build the closing normalisation.

        ``layer.size`` is the feature width the layer was constructed with
        (EncoderLayer stores it as ``self.size``); the final LayerNorm is
        sized to match it.
        """
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        """Feed ``x`` (with ``mask``) through every layer, then normalise."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask)
        return self.norm(hidden)

class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with learned scale and shift."""

    def __init__(self, features, eps=1e-6):
        """Create a scale ``a_2`` (ones) and a shift ``b_2`` (zeros) of length ``features``."""
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        """Normalise ``x`` along its last dimension, then scale and shift.

        Statistics are taken over dim ``-1`` because that is the axis the
        length-``features`` parameters broadcast against. ``eps`` is added to
        the standard deviation (from ``Tensor.std`` with default settings),
        not to the variance.
        """
        mu = x.mean(-1, keepdim=True)
        sigma = x.std(-1, keepdim=True)
        scaled = self.a_2 * (x - mu)
        return scaled / (sigma + self.eps) + self.b_2

    
# Residual connection
class SublayerConnection(nn.Module):
    """Residual wrapper: ``x + dropout(sublayer(norm(x)))``.

    The normalisation is applied to the sublayer input, before the sublayer
    runs, rather than to the sum.
    """

    def __init__(self, size, dropout):
        """``size`` is the feature width for LayerNorm; ``dropout`` is a drop probability."""
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """Apply ``sublayer`` to the normalised input and add the result back onto ``x``."""
        normed = self.norm(x)
        update = self.dropout(sublayer(normed))
        return x + update

#---->norm-->attention->add->norm->mlp->add
#   |                    |  |              |
#    --------------------    --------------
class EncoderLayer(nn.Module):
    """One encoder layer: self-attention then feed-forward, each in a residual wrapper."""

    def __init__(self, size, self_attn, feed_forward, dropout):
        """Store the two sublayers and create one SublayerConnection for each.

        ``size`` is kept as ``self.size`` so that the enclosing stack can read
        the feature width back from the layer.
        """
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, x, mask):
        """Run masked self-attention followed by the feed-forward network."""

        def self_attend(normed):
            # Query, key and value are all the same (normalised) tensor.
            return self.self_attn(normed, normed, normed, mask)

        attn_wrapper, ff_wrapper = self.sublayer
        x = attn_wrapper(x, self_attend)
        return ff_wrapper(x, self.feed_forward)

# --------------------------------Decoder-----------------------------------
class Decoder(nn.Module):
    """A stack of ``N`` copies of ``layer`` followed by a final LayerNorm."""

    def __init__(self, layer, N):
        """Clone ``layer`` ``N`` times; the closing norm is sized from ``layer.size``."""
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, memory, src_mask, tgt_mask):
        """Pass ``x`` through every layer with ``memory`` and both masks, then normalise."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, memory, src_mask, tgt_mask)
        return self.norm(hidden)

class DecoderLayer(nn.Module):
    """One decoder layer: self-attention, attention over ``memory``, feed-forward.

    Each of the three sublayers is wrapped in its own SublayerConnection.
    """

    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        """Store the three sublayers and create three residual wrappers."""
        super().__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, memory, src_mask, tgt_mask):
        """Run the three sublayers in order and return the result.

        ``tgt_mask`` is handed to the self-attention over ``x``; ``src_mask``
        is handed to the attention in which ``x`` supplies the queries and
        ``memory`` supplies both keys and values.
        """

        def self_attend(normed):
            return self.self_attn(normed, normed, normed, tgt_mask)

        def attend_to_memory(normed):
            return self.src_attn(normed, memory, memory, src_mask)

        self_wrapper, memory_wrapper, ff_wrapper = self.sublayer
        x = self_wrapper(x, self_attend)
        x = memory_wrapper(x, attend_to_memory)
        return ff_wrapper(x, self.feed_forward)
    
class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward network: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, d_model, d_ff, dropout=0.1):
        """Create the ``d_model -> d_ff`` and ``d_ff -> d_model`` projections."""
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Expand to ``d_ff``, apply ReLU and dropout, project back to ``d_model``."""
        expanded = self.w_1(x)
        activated = torch.relu(expanded)
        return self.w_2(self.dropout(activated))

class Embeddings(nn.Module):
    """Embedding lookup whose output is multiplied by ``sqrt(d_model)``."""

    def __init__(self, d_model, vocab):
        """Create a ``vocab x d_model`` lookup table and remember ``d_model``."""
        super().__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        """Look up the indices in ``x`` and scale the vectors by ``sqrt(d_model)``."""
        scale = math.sqrt(self.d_model)
        return self.lut(x) * scale

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position signals to the input, then apply dropout.

    Even feature columns hold ``sin(pos * w_k)`` and odd columns hold
    ``cos(pos * w_k)`` with ``w_k = exp(-2k * ln(10000) / d_model)``.
    """

    def __init__(self, d_model, dropout, max_len=5000):
        """Precompute the table for positions ``0 .. max_len - 1``.

        Args:
            d_model: number of feature columns; may be even or odd.
            dropout: drop probability applied after the addition.
            max_len: number of positions to precompute.
        """
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns, so
        # trim the angles to fit; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)
        # Registered as a buffer: saved and moved with the module, but not a
        # trainable parameter.
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Add the first ``x.size(1)`` rows of the table to ``x`` and apply dropout."""
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)



def make_model(
    src_vocab, tgt_vocab, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1
):
    """Assemble an EncoderDecoder model from hyperparameters.

    Args:
        src_vocab: size of the source vocabulary.
        tgt_vocab: size of the target vocabulary.
        N: number of layers in the encoder and in the decoder.
        d_model: feature width used throughout the model.
        d_ff: hidden width of the feed-forward sublayers.
        h: number of attention heads.
        dropout: drop probability for the feed-forward, positional-encoding
            and residual wrappers.

    Returns:
        The model, with every parameter of more than one dimension
        re-initialised by ``nn.init.xavier_uniform_``.
    """
    # Prototype modules; every use below takes its own deep copy so that no
    # two sublayers share parameters.
    # NOTE(review): `dropout` is not forwarded to MultiHeadedAttention, so the
    # attention weights always use that class's default rate.
    attn_proto = MultiHeadedAttention(h, d_model)
    ff_proto = PositionwiseFeedForward(d_model, d_ff, dropout)
    pos_proto = PositionalEncoding(d_model, dropout)
    dup = copy.deepcopy

    encoder = Encoder(
        EncoderLayer(d_model, dup(attn_proto), dup(ff_proto), dropout), N
    )
    decoder = Decoder(
        DecoderLayer(
            d_model, dup(attn_proto), dup(attn_proto), dup(ff_proto), dropout
        ),
        N,
    )
    src_embed = nn.Sequential(Embeddings(d_model, src_vocab), dup(pos_proto))
    tgt_embed = nn.Sequential(Embeddings(d_model, tgt_vocab), dup(pos_proto))
    generator = Generator(d_model, tgt_vocab)
    model = EncoderDecoder(encoder, decoder, src_embed, tgt_embed, generator)

    # Glorot / fan_avg initialisation for weight matrices; one-dimensional
    # parameters (biases, LayerNorm scale and shift) are left untouched.
    for param in model.parameters():
        if param.dim() > 1:
            nn.init.xavier_uniform_(param)
    return model

def inference_test():
    """Greedy-decode nine tokens from a small untrained model and print them.

    Builds a 2-layer model with vocabulary size 11, encodes one fixed
    10-token source sequence, then repeatedly decodes, takes the highest
    scoring token at the last position and appends it to the output.
    """
    test_model = make_model(11, 11, 2)
    # eval() turns dropout off so each forward pass is deterministic.
    test_model.eval()
    src = torch.LongTensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
    src_mask = torch.ones(1, 1, 10)

    # The source is encoded once and reused at every decoding step.
    memory = test_model.encode(src, src_mask)
    # Decoding starts from a single token with id 0.
    ys = torch.zeros(1, 1).type_as(src)

    for _ in range(9):
        out = test_model.decode(
            memory, src_mask, ys, subsequent_mask(ys.size(1)).type_as(src)
        )
        # Only the last position is needed to predict the next token.
        prob = test_model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word[0]
        ys = torch.cat(
            [ys, torch.empty(1, 1).type_as(src).fill_(next_word)], dim=1
        )

    print("Example Untrained Model Prediction:", ys)

def show_example(fn, args=()):
    """Call ``fn(*args)`` only when run as a script with examples enabled.

    Args:
        fn: the callable to run.
        args: positional arguments for ``fn``; any iterable is accepted.
            The default is an immutable empty tuple rather than a shared
            mutable list.

    Returns:
        The result of ``fn(*args)``, or ``None`` when the file is imported
        instead of executed or ``RUN_EXAMPLES`` is falsy.
    """
    if __name__ == "__main__" and RUN_EXAMPLES:
        return fn(*args)
    
def run_tests():
    """Run ``inference_test`` ten times in a row."""
    remaining = 10
    while remaining > 0:
        inference_test()
        remaining -= 1


# Entry point: run_tests is handed to show_example, which only calls it when
# this file is executed as a script and RUN_EXAMPLES is truthy.
show_example(run_tests)