In [2]:
import math
import collections
import torch
from torch import nn
from torch.nn import functional as F
import matplotlib.pyplot as plt

In [3]:
class Encoder(nn.Module):
    """Abstract encoder interface for the encoder-decoder architecture.

    Subclasses must override ``forward``; calling the base implementation
    raises ``NotImplementedError``.
    """

    def __init__(self, **kwargs) -> None:
        # Extra keyword arguments are forwarded unchanged to nn.Module.
        super().__init__(**kwargs)

    def forward(self, X, *args):
        """Encode ``X``; must be provided by a concrete subclass."""
        raise NotImplementedError()

In [4]:
class Decoder(nn.Module):
    """Abstract decoder interface for the encoder-decoder architecture.

    Concrete decoders override ``init_state`` (to turn encoder outputs into
    an initial decoding state) and ``forward``. Both base implementations
    raise ``NotImplementedError``.
    """

    def __init__(self, **kwargs) -> None:
        # Extra keyword arguments are forwarded unchanged to nn.Module.
        super().__init__(**kwargs)

    def init_state(self, enc_outputs, *args):
        """Build the initial decoder state from ``enc_outputs``."""
        raise NotImplementedError()

    def forward(self, X, state):
        """Decode ``X`` given ``state``; must be provided by a subclass."""
        raise NotImplementedError()

In [5]:
class EncoderDecoder(nn.Module):
    """Glue module that chains an encoder and a decoder.

    Args:
        encoder: module called as ``encoder(enc_X, *args)``.
        decoder: module exposing ``init_state(enc_outputs, *args)`` and
            callable as ``decoder(dec_X, state)``.
    """

    def __init__(self, encoder, decoder, **kwargs) -> None:
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        """Encode ``enc_X``, derive the decoder state, then decode ``dec_X``.

        The same ``*args`` are passed to both the encoder call and
        ``decoder.init_state``. Returns whatever the decoder returns.
        """
        encoded = self.encoder(enc_X, *args)
        initial_state = self.decoder.init_state(encoded, *args)
        return self.decoder(dec_X, initial_state)

In [None]:
def masked_softmax(X,valid_lens):
    """Softmax over the last axis of ``X``, ignoring positions past ``valid_lens``.

    Args:
        X: 3-D score tensor, expected as (batch, num_queries, num_keys).
        valid_lens: ``None`` (plain softmax), a 1-D tensor with one valid
            length per batch element, or a 2-D tensor with one valid length
            per (batch, query) row.

    Returns:
        Tensor of the same shape as ``X``. Entries at key positions
        ``>= valid_len`` receive (numerically) zero probability. ``X`` itself
        is not modified.
    """
    if valid_lens is None:
        return nn.functional.softmax(X,dim=-1)
    shape=X.shape
    if valid_lens.dim()==1:
        # One length per batch element: repeat it once per query row
        # (shape[1]), so that it lines up with the flattened rows below.
        valid_lens=torch.repeat_interleave(valid_lens,shape[1])
    else:
        valid_lens=valid_lens.reshape(-1)
    # Flatten to (batch*num_queries, num_keys) so each row has exactly one length.
    flat=X.reshape(-1,shape[-1])
    positions=torch.arange(shape[-1],dtype=torch.float32,device=X.device)
    keep=positions[None,:]<valid_lens.to(X.device)[:,None]
    # A large negative score makes exp() underflow to 0 for masked positions.
    flat=flat.masked_fill(~keep,-1e6)
    return nn.functional.softmax(flat.reshape(shape),dim=-1)
        

In [None]:
class Additiveattention(nn.Module):
    """Additive attention: score(q, k) = w_v^T tanh(W_q q + W_k k).

    Args:
        key_size: feature size of each key.
        query_size: feature size of each query.
        num_hiddens: size of the shared hidden projection.
        dropout: dropout probability applied to the attention weights.
    """
    def __init__(self,key_size,query_size,num_hiddens,dropout,**kwargs) -> None:
        super(Additiveattention,self).__init__(**kwargs)
        self.W_k=nn.Linear(key_size,num_hiddens,bias=False)
        self.W_q=nn.Linear(query_size,num_hiddens,bias=False)
        self.W_v=nn.Linear(num_hiddens,1,bias=False)
        self.dropout=nn.Dropout(dropout)
    def forward(self,queries,keys,values,valid_lens):
        """Return the attention-weighted sum of ``values``.

        Expected shapes (3-D inputs, as required by ``torch.bmm``):
            queries: (batch, num_queries, query_size)
            keys:    (batch, num_keys, key_size)
            values:  (batch, num_keys, value_size)
            valid_lens: ``None`` or lengths accepted by ``masked_softmax``.

        Returns:
            Tensor of shape (batch, num_queries, value_size). The weights
            used are stored on ``self.attention_weights``.
        """
        queries,keys=self.W_q(queries),self.W_k(keys)
        # Broadcast-add every query against every key:
        # (batch, num_queries, 1, h) + (batch, 1, num_keys, h)
        features=queries.unsqueeze(2)+keys.unsqueeze(1)
        features=torch.tanh(features)
        # W_v maps h -> 1; dropping that axis leaves (batch, num_queries, num_keys).
        scores=self.W_v(features).squeeze(-1)
        self.attention_weights=masked_softmax(scores,valid_lens)
        return torch.bmm(self.dropout(self.attention_weights),values)
        

In [None]:
class Attentiondecode(Decoder):
    """Base class for attention-based decoders.

    Adds nothing to ``Decoder`` beyond its own type; ``init_state`` and
    ``forward`` are inherited unchanged.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
    

In [None]:
class Seq2seqattentiondecoder(Attentiondecode):
    """GRU decoder equipped with an additive-attention module.

    Args:
        vocab_size: number of tokens in the target vocabulary.
        embed_size: embedding dimension.
        num_hiddens: GRU hidden size; also used as the key, query and hidden
            size of the attention module.
        num_layer: number of stacked GRU layers.
        dropout: dropout probability for the GRU and the attention weights.

    NOTE: ``init_state`` and ``forward`` are not defined here, so the
    inherited ``Decoder`` versions (which raise ``NotImplementedError``)
    still apply.
    """
    def __init__(self,vocab_size,embed_size,num_hiddens,num_layer,dropout=0,**kwargs) -> None:
        super(Seq2seqattentiondecoder,self).__init__(**kwargs)
        self.embedding=nn.Embedding(vocab_size,embed_size)
        # TODO(review): if forward() concatenates an attention context of size
        # num_hiddens to each embedding (as Seq2seqdecoder does with its
        # context), the GRU input size must become embed_size+num_hiddens.
        self.rnn=nn.GRU(embed_size, num_hiddens,num_layer,dropout=dropout)
        # Previously a bare `self.attention` expression, which raised
        # AttributeError on construction. Keys, queries and the hidden
        # projection all use num_hiddens.
        self.attention=Additiveattention(num_hiddens,num_hiddens,num_hiddens,dropout)
       

In [7]:
class Seq2seqencoder(Encoder):
    """GRU encoder for sequence-to-sequence learning.

    Args:
        vocab_size: number of tokens in the source vocabulary.
        embed_size: embedding dimension.
        num_hiddens: GRU hidden size.
        num_layer: number of stacked GRU layers.
        dropout: dropout probability between GRU layers.
    """

    def __init__(self, vocab_size, embed_size, num_hiddens, num_layer, dropout=0, **kwargs) -> None:
        super().__init__(**kwargs)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, num_hiddens, num_layer, dropout=dropout)

    def forward(self, X, *args):
        """Embed token ids ``X`` and run them through the GRU.

        Returns:
            ``(output, state)`` exactly as produced by the GRU.
        """
        # The GRU is not batch_first, so swap the first two axes after
        # embedding to put the time axis in front.
        time_major = self.embedding(X).permute(1, 0, 2)
        output, state = self.rnn(time_major)
        return output, state

In [9]:
# Smoke-test the encoder on a dummy batch of token ids.
encoder=Seq2seqencoder(vocab_size=10,embed_size=8,num_hiddens=16,num_layer=2)
# eval() switches the module to inference mode (disables dropout).
encoder.eval()
# Dummy input of shape (4, 7): all-zero int64 token ids.
X=torch.zeros((4,7),dtype=torch.long)
# `output` and `state` are reused by later cells.
output,state=encoder(X)
# The encoder permutes to time-major, so the first two axes are swapped relative to X.
output.shape


torch.Size([7, 4, 16])

In [10]:
# Shape of the GRU's final hidden state returned by the encoder call above.
state.shape

torch.Size([2, 4, 16])

In [11]:
class Seq2seqdecoder(Decoder):
    """GRU decoder for sequence-to-sequence learning.

    At every time step the GRU input is the token embedding concatenated
    with a context vector taken from the top layer of the incoming state.

    Args:
        vocab_size: number of tokens in the target vocabulary.
        embed_size: embedding dimension.
        num_hiddens: GRU hidden size (also the context size).
        num_layer: number of stacked GRU layers.
        dropout: dropout probability between GRU layers.
    """

    def __init__(self, vocab_size, embed_size, num_hiddens, num_layer, dropout=0, **kwargs) -> None:
        super().__init__(**kwargs)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # Input is embedding + context, hence embed_size + num_hiddens.
        self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layer, dropout=dropout)
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, *args):
        """Use the second element of ``enc_outputs`` as the initial state."""
        return enc_outputs[1]

    def forward(self, X, state):
        """Decode token ids ``X`` starting from hidden state ``state``.

        Returns:
            ``(output, state)``: the dense-layer scores with the first two
            axes swapped back, and the GRU's updated hidden state.
        """
        embedded = self.embedding(X).permute(1, 0, 2)
        num_steps = embedded.shape[0]
        # The most condensed information is the top layer's output at the
        # last time step; repeat it so every step sees the same context.
        context = state[-1].repeat(num_steps, 1, 1)
        rnn_input = torch.cat((embedded, context), 2)
        output, state = self.rnn(rnn_input, state)
        scores = self.dense(output).permute(1, 0, 2)
        return scores, state


In [12]:
# Smoke-test the decoder, reusing `encoder` and `X` from the encoder cell.
decoder=Seq2seqdecoder(vocab_size=10,embed_size=8,num_hiddens=16,num_layer=2)
# eval() switches the module to inference mode (disables dropout).
decoder.eval()

# init_state picks the encoder's final hidden state (element 1 of its output tuple).
state=decoder.init_state(encoder(X))
# NOTE: rebinds the global `output` and `state` set by the encoder cell.
output,state=decoder(X,state)
output.shape,state.shape

(torch.Size([4, 7, 10]), torch.Size([2, 4, 16]))

In [13]:
def sequence_mask(X, valid_len, value=0):
    """Overwrite entries of ``X`` beyond each row's valid length, in place.

    Args:
        X: tensor whose axis 0 indexes sequences and axis 1 indexes steps.
        valid_len: 1-D tensor holding the number of valid steps per sequence.
        value: fill value written at step positions ``>= valid_len``.

    Returns:
        The same tensor object ``X`` (it is modified in place).
    """
    num_steps = X.size(1)
    positions = torch.arange(num_steps, dtype=torch.float32, device=X.device)
    # keep[i, t] is True while step t is still inside sequence i.
    keep = positions[None, :] < valid_len[:, None]
    X[~keep] = value
    return X

In [14]:
class MasksoftmaxCELoss(nn.CrossEntropyLoss):
    """Softmax cross-entropy loss that ignores positions past ``valid_len``."""

    def forward(self, pred, label, valid_len):
        """Compute the masked loss, averaged over axis 1.

        ``pred`` is permuted with ``(0, 2, 1)`` before being handed to the
        parent loss together with ``label``; ``valid_len`` holds the valid
        length of each sequence.

        Returns:
            The per-sequence loss, with masked positions contributing zero
            to the mean over axis 1.
        """
        # 1 inside the valid length, 0 (the sequence_mask default) beyond it.
        mask = sequence_mask(torch.ones_like(label), valid_len)
        # Per-element losses are needed so the mask can be applied; this
        # setting persists on the instance after the call.
        self.reduction = 'none'
        per_token = super().forward(pred.permute(0, 2, 1), label)
        return (per_token * mask).mean(dim=1)
