# seq2seq模型测试
---

与此前的 seq2seq 实验相比，本次实验采用了不同的数据集构建方案，并使用了更复杂的模型。

In [None]:
import os
import re
from tqdm import tqdm
import sys
import random
import pprint
import torch
import torch.nn as nn


sys.path.insert(0, "/home/team55/notespace/zengbin")

import jddc.utils as u
import jddc.datasets as d
from seq2seq.fields import *
from seq2seq.optim import Optimizer
from seq2seq.models import EncoderRNN, DecoderRNN, Seq2seq
from seq2seq.loss import NLLLoss
from seq2seq.supervised_trainer import SupervisedTrainer

## 参数配置

---

In [None]:
class Seq2SeqConfig(object):
    """Parameter configuration for the Seq2Seq model.

    Class attributes hold the model hyper-parameters and the two keyword
    dictionaries (``encoder_params`` / ``decoder_params``) that are unpacked
    into the encoder and decoder constructors.  Instance attributes, set in
    ``__init__``, hold the experiment directory and the file paths inside it.

    NOTE: ``encoder_params`` and ``decoder_params`` are class-level objects,
    so they are shared by every instance of this class.
    """
    use_cuda = True
    # Value handed to ``.cuda(...)`` and to the decoder/trainer as ``device``;
    # presumably a GPU index -- confirm against the seq2seq package.
    device = 1
    teacher_forcing_ratio = 0.5

    # encoder & decoder
    hidden_size = 256
    n_layers = 4
    bidirectional = True
    max_len = 300
    rnn_cell = 'lstm'

    # Keyword arguments for the encoder.
    encoder_params = u.AttrDict()
    encoder_params['hidden_size'] = hidden_size
    encoder_params['n_layers'] = n_layers
    encoder_params['bidirectional'] = bidirectional
    encoder_params['max_len'] = max_len
    encoder_params['rnn_cell'] = rnn_cell
    encoder_params['variable_lengths'] = True
    encoder_params['input_dropout_p'] = 0
    encoder_params['dropout_p'] = 0.5

    # Keyword arguments for the decoder.
    decoder_params = u.AttrDict()
    # The decoder hidden size is doubled when the encoder is bidirectional
    # (presumably to match the concatenated forward/backward encoder state).
    decoder_params['hidden_size'] = hidden_size*2 if bidirectional else hidden_size
    decoder_params['n_layers'] = n_layers
    decoder_params['bidirectional'] = bidirectional
    decoder_params['max_len'] = max_len
    decoder_params['rnn_cell'] = rnn_cell
    decoder_params['use_attention'] = True
    decoder_params['device'] = device
    decoder_params['input_dropout_p'] = 0
    decoder_params['dropout_p'] = 0.5

    def __init__(self, data_root="/home/team55/notespace/data"):
        """Resolve the experiment directory and the files stored in it.

        Args:
            data_root: Directory under which the ``seq2seq02`` experiment
                folder lives.  Defaults to the original hard-coded location,
                so ``Seq2SeqConfig()`` behaves exactly as before.
        """
        # Directory where the model and its artefacts are stored.
        self.s2s_path = os.path.join(data_root, "seq2seq02")
        u.insure_folder_exists(self.s2s_path)
        self.file_train = os.path.join(self.s2s_path, "train.tsv")
        # Training file with the reversed question tokenization
        # (translated from the original comment; confirm the exact meaning).
        self.file_train_rq = os.path.join(self.s2s_path, "train_reverse_q.tsv")
        self.log_file = os.path.join(self.s2s_path, "seq2seq_02.log")

In [None]:
# Build the experiment configuration; its constructor also calls
# u.insure_folder_exists on the experiment directory.
conf = Seq2SeqConfig()

# Logger named "s2s" that writes to the experiment log file
# (cmd=True presumably also echoes to the console -- confirm in jddc.utils).
logger = u.create_logger(
    conf.log_file,
    name="s2s",
    cmd=True,
)

## 加载数据
---

In [None]:
# Work around "_csv.Error: field larger than field limit (131072)" by
# raising the csv module's per-field size limit to 500 MiB.
import csv
csv.field_size_limit(500 << 20)  # same value as 500 * 1024 * 1024

# Source and target fields, both constructed with batch_first=True.
src = SourceField(batch_first=True)
tgt = TargetField(batch_first=True)

# Length bound read by len_filter when the dataset is loaded.
max_len = conf.max_len

def len_filter(example):
    """Return True when neither side of ``example`` is longer than ``max_len``.

    ``example`` must expose ``src`` and ``tgt`` attributes that support
    ``len()``; ``max_len`` is read from the module-level variable.
    """
    # Reject as soon as the source side is too long (tgt is not inspected).
    if len(example.src) > max_len:
        return False
    return len(example.tgt) <= max_len

# Import torchtext explicitly: the file never imports it by name, so the
# reference below only resolved if the name leaked in through
# `from seq2seq.fields import *`.
import torchtext

# Load the two-column TSV (source, target); examples rejected by len_filter
# are dropped while loading.
train = torchtext.data.TabularDataset(
    path=conf.file_train_rq,
    format='tsv',
    fields=[('src', src), ('tgt', tgt)],
    filter_pred=len_filter,
)

# Build one vocabulary per side from the training set, each capped at
# max_size=100000.
src.build_vocab(train, max_size=100000)
tgt.build_vocab(train, max_size=100000)
input_vocab = src.vocab
output_vocab = tgt.vocab

## 创建模型
---

In [None]:
# Show the keyword arguments that will be passed to the encoder and the
# decoder, one per line.
print(conf.encoder_params, conf.decoder_params, sep="\n")

In [None]:
# Loss plus the encoder/decoder pair; vocabulary sizes come from the fields
# built during data loading, everything else from the config dictionaries.
loss = NLLLoss()
encoder = EncoderRNN(
    vocab_size=len(src.vocab),
    **conf.encoder_params
)
decoder = DecoderRNN(
    vocab_size=len(tgt.vocab),
    eos_id=tgt.eos_id,
    sos_id=tgt.sos_id,
    **conf.decoder_params
)
seq2seq = Seq2seq(encoder, decoder)

# Move the model and the loss to the configured device when CUDA is enabled.
device = conf.device
if conf.use_cuda:
    seq2seq.cuda(device)
    loss.cuda(device)

# Overwrite every model parameter with values drawn uniformly from
# [-0.08, 0.08]; this runs after the optional move to the GPU.
for param in seq2seq.parameters():
    param.data.uniform_(-0.08, 0.08)

## 训练模型
---

In [None]:
# The optimizer and learning-rate scheduler can be customised by explicitly
# constructing the objects and passing them to the trainer.
optimizer = Optimizer(torch.optim.Adam(seq2seq.parameters()), max_grad_norm=5)
# scheduler = StepLR(optimizer.optimizer, 1)
# optimizer.set_scheduler(scheduler)

# Train.  The seed is passed as an int (it was the string "1234"), because
# torch.manual_seed is documented to take an int.
# NOTE(review): the model parameters are initialised before the trainer is
# constructed, so this seed presumably does not cover the weight init.
trainer = SupervisedTrainer(loss=loss, batch_size=32, checkpoint_every=500, print_every=10,
                            expt_dir=conf.s2s_path, random_seed=1234, device=conf.device)
trainer.logger = logger
# Take the teacher-forcing ratio from the config (same value, 0.5) instead
# of repeating the literal here.
seq2seq = trainer.train(seq2seq, train, num_epochs=3, optimizer=optimizer,
                        teacher_forcing_ratio=conf.teacher_forcing_ratio, resume=False)