## 简介
当输入序列长度和输出标签序列长度都可变时(即输入序列的子序列个数和输出序列的子序列个数不同)

数据预处理:
指定子序列长度,填充<pad>词元或裁剪得到固定长度的子序列,增加<eos>词元至子序列末尾

encoder(编码器):
1.子序列(batchsize,num_steps,vocabsize)用embedding层降维至(batchsize,num_steps,embed_size) #num_steps是子序列的词元个数,即RNN的时间步骤个数
2.嵌入后的子序列(batchsize,num_steps,embed_size)经RNN输出为(batchsize,num_steps,num_hiddens)和每个子序列的末尾时间步骤的状态(num_layers,batchsize,num_hiddens) #num_hiddens是隐藏层特征数,num_layers是RNN层数;注意状态张量不受batch_first影响,第0维始终是num_layers

decoder(解码器):(N个长度固定的子序列->batchsize个为一批次,子序列长度为num_steps)
1.子序列(batchsize,num_steps,vocabsize)用embedding层降维至(batchsize,num_steps,embed_size) #num_steps是子序列的词元个数,即RNN的时间步骤个数
2.取encoder输出的最后一个时间步骤的预测的特征(batchsize,num_hiddens),添加为decoder子序列的特征(batchsize,num_steps,embed_size+num_hiddens) #decoder的每个时间步都包含encoder的特征?
3.拼接后的decoder子序列(batchsize,num_steps,embed_size+num_hiddens)的RNN输出为(batchsize,num_steps,num_hiddens),并通过linear层转为(batchsize,num_steps,vocabsize),其中RNN初始状态是encoder的末尾时间步骤的状态(num_layers,batchsize,num_hiddens)


In [1]:
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

  from .autonotebook import tqdm as notebook_tqdm


## 模型定义

In [None]:
class Lit_encoder(pl.LightningModule):
    """GRU encoder: embeds source token ids and runs them through a GRU.

    Args:
        vocab_size: size of the source vocabulary (rows of the embedding).
        embed_size: dimension of each token embedding.
        num_hiddens: GRU hidden size.
        num_layers: number of stacked GRU layers.
        dropout: dropout applied between GRU layers (passed to ``torch.nn.GRU``).
    """

    def __init__(self,vocab_size,embed_size,num_hiddens,num_layers,dropout=0):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size,embed_size)
        self.rnn = torch.nn.GRU(embed_size,num_hiddens,num_layers,batch_first=True,dropout=dropout)

    def forward(self,x):
        """Encode a batch of token-id sequences.

        Args:
            x: integer tensor of token ids; with ``batch_first=True`` the GRU
                expects the embedded input as (batch_size, num_steps, embed_size).

        Returns:
            ``(outputs, state)`` exactly as produced by ``torch.nn.GRU``.
        """
        x = self.embedding(x)
        outputs,state = self.rnn(x)
        # outputs: (batch_size, num_steps, num_hiddens) because batch_first=True
        # state:   (num_layers, batch_size, num_hiddens) -- batch_first does NOT
        #          apply to the hidden state, its first axis is always num_layers
        return outputs, state

class Lit_decoder(pl.LightningModule):
    """GRU decoder conditioned on the encoder's last-step output.

    At every decoder time step the token embedding is concatenated with a
    context vector taken from the encoder output at its last time step, so the
    GRU input size is ``embed_size + num_hiddens``.

    Args:
        vocab_size: size of the target vocabulary.
        embed_size: dimension of each token embedding.
        num_hiddens: GRU hidden size (must match the encoder's).
        num_layers: number of stacked GRU layers (must match the encoder's,
            since the encoder state is used as the initial decoder state).
        dropout: dropout applied between GRU layers.
    """

    def __init__(self,vocab_size,embed_size,num_hiddens,num_layers,dropout=0):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size,embed_size)
        self.rnn = torch.nn.GRU(embed_size+num_hiddens,num_hiddens,num_layers,batch_first=True,dropout=dropout)
        self.dense = torch.nn.Linear(num_hiddens,vocab_size)

    def forward(self,x,state):
        """Decode a batch of target token ids.

        Args:
            x: integer tensor of target token ids, (batch_size, num_steps).
            state: pair ``(enc_outputs, rnn_state)``; ``enc_outputs`` is the
                encoder output (batch_size, enc_steps, num_hiddens) and
                ``rnn_state`` is the GRU state used to initialise this call.

        Returns:
            ``(logits, [enc_outputs, dec_state])`` where ``logits`` has the
            vocabulary size as its last dimension and the returned state list
            can be fed back into the next ``forward`` call.
        """
        x = self.embedding(x)
        enc_outputs, enc_state = state
        # The encoder GRU is batch_first, so the time axis is dim 1: take the
        # last time step of every sequence -> (batch_size, num_hiddens).
        # (Indexing with [-1] alone would pick the last *batch element*.)
        context = enc_outputs[:, -1, :]
        # Broadcast the context to every decoder step:
        # (batch_size, num_steps, num_hiddens)
        context = context.unsqueeze(1).repeat(1, x.shape[1], 1)
        x = torch.cat((x, context), 2)
        outputs,dec_state = self.rnn(x,enc_state)
        outputs = self.dense(outputs)
        return outputs, [enc_outputs,dec_state]

class Lit_encoder_decoder(pl.LightningModule):
    """Wraps an encoder and a decoder into one trainable seq2seq module.

    Args:
        encoder: module whose ``forward(enc_x)`` returns ``(outputs, state)``.
        decoder: module whose ``forward(dec_x, enc_result)`` returns
            ``(logits, state)``.
    """

    def __init__(self,encoder,decoder):
        super().__init__()
        self.encoder=encoder
        self.decoder=decoder

    def forward(self,enc_x,dec_x):
        """Run the encoder on ``enc_x`` and the decoder on ``dec_x``.

        Returns:
            Only the decoder logits (the decoder state is discarded).
        """
        enc_result=self.encoder(enc_x)
        dec_result=self.decoder(dec_x,enc_result)
        return dec_result[0]

    def training_step(self,batch,batch_idx):
        """Compute the teacher-forced cross-entropy loss for one batch.

        ``batch`` is unpacked as ``(x, y)``.
        NOTE(review): assumes ``y`` holds target token ids of shape
        (batch_size, num_steps) that start with ``<bos>`` -- confirm against
        the DataLoader that feeds this module.
        """
        x,y=batch
        # Teacher forcing: the decoder sees tokens [0 .. T-2] and must predict
        # tokens [1 .. T-1]. Feeding y and scoring against the same y would let
        # the model simply copy its input.
        dec_input=y[:,:-1]
        labels=y[:,1:]
        # forward() returns a single logits tensor, so there is nothing to unpack.
        y_pred=self(x,dec_input)
        # reshape (not view): the sliced labels are not contiguous.
        loss=torch.nn.functional.cross_entropy(y_pred.reshape(-1,y_pred.shape[-1]),labels.reshape(-1))
        # NOTE(review): <pad> positions are still included in the loss; pass
        # ignore_index=<pad id> once the vocabulary index of <pad> is known.
        self.log('train_loss',loss, prog_bar=True, logger=True, on_epoch=True,on_step=True)
        return loss

    def configure_optimizers(self):
        """Use Adam over all parameters with a fixed learning rate of 1e-3."""
        return torch.optim.Adam(self.parameters(),lr=0.001)



## 数据集

In [54]:
import requests
import os
import re
import zipfile
class LitLoadData_fra(pl.LightningDataModule):
    """Data module that fetches the English-French sentence-pair corpus."""

    def prepare_data(self):
        """Download and unpack the corpus unless ``fra.txt`` already exists.

        After this call ``./data/fra-eng/fra.txt`` is present on disk.

        Raises:
            requests.HTTPError: if the server answers with an error status.
            FileNotFoundError: if the archive contains no ``fra.txt``.
        """
        url = 'http://d2l-data.s3-accelerate.amazonaws.com/fra-eng.zip'
        data_dir = './data/fra-eng'
        zip_path = './data/fra-eng.zip'
        txt_path = './data/fra-eng/fra.txt'
        # Nothing to do when the corpus is already unpacked.
        if os.path.exists(txt_path):
            return
        os.makedirs(data_dir, exist_ok=True)
        # Stream the archive to disk; the response body must actually be
        # written out before zipfile can open it.
        with requests.get(url, stream=True, timeout=60) as r:
            r.raise_for_status()
            with open(zip_path, 'wb') as zip_file:
                for chunk in r.iter_content(chunk_size=1 << 16):
                    zip_file.write(chunk)
        # Unpack the archive.
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(data_dir)
            if not os.path.exists(txt_path):
                # The archive may nest its files under a top-level folder;
                # copy fra.txt to the path the rest of the code reads from.
                for member in zip_ref.namelist():
                    if os.path.basename(member) == 'fra.txt':
                        with zip_ref.open(member) as src_file, open(txt_path, 'wb') as dst_file:
                            dst_file.write(src_file.read())
                        break
                else:
                    raise FileNotFoundError('fra.txt not found in ' + zip_path)
            
# Make sure the corpus has been downloaded and unpacked.
data = LitLoadData_fra()
data.prepare_data()

# Read the whole of fra.txt and preview the first few sentence pairs.
with open('./data/fra-eng/fra.txt', encoding='utf-8') as f:
    raw_txt = f.read()
print(raw_txt[:75])

Go.	Va !
Hi.	Salut !
Run!	Cours !
Run!	Courez !
Who?	Qui ?
Wow!	Ça alors !



In [None]:
def preprocess(raw_txt,max_tokens=10000,num_steps=9):
    """Turn the raw tab-separated corpus into fixed-length token lists.

    The text is lower-cased and split into lines; empty lines are dropped and
    only the first ``max_tokens`` remaining lines are kept. Lines that do not
    split into exactly two tab-separated fields are skipped. Words and single
    punctuation marks become tokens. ``<eos>`` is appended to both sides,
    ``<bos>`` is prepended to the target side, and every sequence is then
    truncated or right-padded with ``<pad>`` to ``num_steps`` tokens.

    Args:
        raw_txt: full corpus text, one ``source<TAB>target`` pair per line.
        max_tokens: maximum number of non-empty lines to use.
        num_steps: length of every returned token sequence.

    Returns:
        ``(src, tgt)``: two parallel lists of token lists.
    """
    # A token is a run of word characters or one non-space punctuation mark.
    token_pattern = r'\w+|[^\w\s]'

    def fit_length(tokens):
        # Truncate to num_steps; a non-positive repeat count yields no padding.
        return tokens[:num_steps] + ['<pad>'] * (num_steps - len(tokens))

    non_empty_rows = [row for row in raw_txt.lower().split('\n') if len(row) > 0]

    src, tgt = [], []
    for row in non_empty_rows[:max_tokens]:
        fields = row.split('\t')
        if len(fields) != 2:
            continue
        source_text, target_text = fields
        source_tokens = re.findall(token_pattern, source_text) + ['<eos>']
        target_tokens = ['<bos>'] + re.findall(token_pattern, target_text) + ['<eos>']
        src.append(fit_length(source_tokens))
        tgt.append(fit_length(target_tokens))
    return src,tgt


# Attach as a static method: a plain function stored on the class would receive
# the instance as its first argument (raw_txt) when called as data.preprocess(...).
LitLoadData_fra.preprocess = staticmethod(preprocess)

# Quick look at the first few tokenised source / target sentences.
src, tgt = preprocess(raw_txt)
print(src[:6])
print(tgt[:6])

[['go', '.', '<eos>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>'], ['hi', '.', '<eos>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>'], ['run', '!', '<eos>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>'], ['run', '!', '<eos>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>'], ['who', '?', '<eos>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>'], ['wow', '!', '<eos>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']]
[['<bos>', 'va', '!', '<eos>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>'], ['<bos>', 'salut', '!', '<eos>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>'], ['<bos>', 'cours', '!', '<eos>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>'], ['<bos>', 'courez', '!', '<eos>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>'], ['<bos>', 'qui', '?', '<eos>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>'], ['<bos>', 'ça', 'alors', '!', '<eos>', '<pad>', '<pad>', '<pad>', '<pad>']]


In [None]:
import collections

def vocab(sentences,min_freq=0):
    """Build a vocabulary from tokenised sentences.

    Index 0 is always ``'<unk>'``; the remaining tokens follow in order of
    decreasing corpus frequency (ties keep first-seen order).

    Args:
        sentences: iterable of token lists.
        min_freq: tokens occurring fewer than this many times are dropped.

    Returns:
        ``(idx_to_token, token_to_idx)``: a list mapping index -> token and
        the inverse dict mapping token -> index.
    """
    counter = collections.Counter(
        token for sentence in sentences for token in sentence
    )
    # Drop rare tokens; '<unk>' is reserved for index 0, so skip it here to
    # avoid a duplicate entry if it already occurs in the data.
    kept = [
        token for token in counter
        if counter[token] >= min_freq and token != '<unk>'
    ]
    # Token indices follow descending frequency.
    kept.sort(key=lambda token: counter[token], reverse=True)
    # Prepend '<unk>' to the token LIST (not to each individual token string).
    idx_to_token = ['<unk>'] + kept
    token_to_idx = {token: idx for idx, token in enumerate(idx_to_token)}
    return idx_to_token,token_to_idx

### 用tokenizers训练文本自动提取token

In [None]:
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import BertPreTokenizer
from tokenizers.processors import TemplateProcessing
from tokenizers.trainers import BpeTrainer

class TokenizerTrainer:
    """Trains a BPE tokenizer that emits fixed-length ``<bos> ... <eos>`` sequences.

    Every encoding is wrapped in ``<bos>``/``<eos>``, then truncated or padded
    with ``<pad>`` to exactly ``num_step`` tokens.

    Args:
        num_step: length of every encoded sequence.
    """

    def __init__(self, num_step=9):
        self.num_step = num_step
        # unk_token makes the model emit <unk> for symbols it has never seen
        # instead of silently dropping them.
        self.tokenizer = Tokenizer(BPE(unk_token="<unk>"))
        self.tokenizer.pre_tokenizer = BertPreTokenizer()
        # Register the special tokens up front so their ids can be looked up
        # below; the trainer is given the same list in the same order.
        self.tokenizer.add_special_tokens(["<pad>", "<bos>", "<eos>", "<unk>"])
        self.trainer = BpeTrainer(special_tokens=["<pad>", "<bos>", "<eos>", "<unk>"], min_frequency=2)
        self.tokenizer.enable_padding(pad_id=self.tokenizer.token_to_id("<pad>"), pad_token="<pad>", length=self.num_step)
        self.tokenizer.enable_truncation(max_length=self.num_step)
        self.tokenizer.post_processor = TemplateProcessing(
            single="<bos> $A <eos>",
            pair="<bos> $A <eos> <bos> $B <eos>",
            special_tokens=[
                ("<bos>", self.tokenizer.token_to_id("<bos>")),
                ("<eos>", self.tokenizer.token_to_id("<eos>")),
            ],
        )

    def train(self, file_path):
        """Train the tokenizer on a UTF-8 text file of tab-separated sentences.

        Each line is split on tabs and the resulting pieces are fed to the
        trainer as one batch of strings.
        """
        with open(file_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        pairs = [line.split('\t') for line in lines]
        self.tokenizer.train_from_iterator(pairs, self.trainer)

    def save(self, path):
        """Serialise the tokenizer to ``path``."""
        self.tokenizer.save(path)

    def encode(self, text):
        """Encode ``text`` and return the tokenizer's encoding object."""
        return self.tokenizer.encode(text)

    def decode(self, ids):
        """Turn a list of token ids back into text."""
        return self.tokenizer.decode(ids)


test_text: Go.
Encoded: ['<bos>', 'Go', '.', '<eos>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']
Decoded: Go .
