實驗：測試強化學習是否有效

real_a -> fake_a'
* Discrimintor: 區分為A or B，越接近B，分數越高
* loss = lambda_a * NLL(fake_a, real_a) + lambda_b * Reward
* 概念：一方面企圖讓G直接還原A，一方面用D來讓G生成像是B的sample

In [8]:
import logging
import os
import sys
import argparse

import torch
import torch.optim as optim
import torch.nn as nn
from torchtext import data
from torchtext import datasets

try:
    pardir = os.path.realpath(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
except NameError:
    pardir = os.path.split(os.getcwd())[0]
if pardir not in sys.path:
    sys.path.insert(0, pardir)

from seq2seq.util.checkpoint import Checkpoint

from ape import Constants, options
from ape.dataset.lang8 import Lang8
from ape.dataset.field import SentencePieceField
from ape.model.discriminator import BinaryClassifierCNN
from ape.model.transformer.Models import Transformer
from ape import trainers

/Users/chi/Work/pytorch-seq2seq


### 設定

In [None]:
LOG_FORMAT = '%(asctime)s %(name)-12s %(levelname)-8s %(message)s'

parser = argparse.ArgumentParser()
opt = options.train_options(parser)
opt = parser.parse_args(args=[])

opt.cuda = torch.cuda.is_available()
opt.device = None if opt.cuda else -1

# 快速變更設定
opt.exp_dir = './experiment/transformer-dualgan/use_billion'
opt.load_vocab_from = './experiment/transformer/lang8-cor2err/vocab.pt'
opt.build_vocab_from = './data/billion/billion.30m.model.vocab'

opt.exp_dir = os.path.join(pardir, opt.exp_dir)
opt.load_vocab_from = os.path.join(pardir, opt.load_vocab_from)
opt.build_vocab_from = os.path.join(pardir, opt.build_vocab_from)

# dataset params
opt.max_len = 20

# G params
opt.load_G_a_from = './experiment/transformer/lang8-err2cor/'
opt.load_G_b_from = './experiment/transformer/lang8-cor2err/'
opt.d_model = 300  # 暫時需要

# D params
opt.embed_dim = opt.d_model
opt.num_kernel = 100
opt.kernel_sizes = [2, 3, 4, 5, 6, 7]
opt.dropout_p = 0.25

# train params
opt.batch_size = 1
opt.n_epoch = 5

if not os.path.exists(opt.exp_dir):
    os.makedirs(opt.exp_dir)

logging.basicConfig(filename=opt.exp_dir + '/.log',
                    format=LOG_FORMAT, level=logging.DEBUG)
logging.getLogger().addHandler(logging.StreamHandler())

logging.info('Use CUDA? ' + str(opt.cuda))
logging.info(opt)

Use CUDA? False
Use CUDA? False
Use CUDA? False
Namespace(batch_size=1, beam_size=5, build_vocab_from='/Users/chi/Work/pytorch-seq2seq/./data/billion/billion.30m.model.vocab', cuda=False, d_inner_hid=1024, d_k=64, d_model=300, d_v=64, d_word_vec=512, device=-1, dropout=0.1, dropout_p=0.25, embed_dim=300, embs_share_weight=False, epoch=10, exp_dir='/Users/chi/Work/pytorch-seq2seq/./experiment/transformer-dualgan/use_billion', kernel_sizes=[2, 3, 4, 5, 6, 7], load_G_a_from='./experiment/transformer/lang8-err2cor/', load_G_b_from='./experiment/transformer/lang8-cor2err/', load_from=None, load_vocab_from='/Users/chi/Work/pytorch-seq2seq/./experiment/transformer/lang8-cor2err/vocab.pt', max_len=20, max_word_seq_len=50, n_best=1, n_epoch=5, n_head=8, n_layers=6, n_warmup_steps=4000, num_kernel=100, proj_share_weight=False)
Namespace(batch_size=1, beam_size=5, build_vocab_from='/Users/chi/Work/pytorch-seq2seq/./data/billion/billion.30m.model.vocab', cuda=False, d_inner_hid=1024, d_k=64, d_mod

### 準備資料

In [19]:
def len_filter(example):
    return len(example.src) <= opt.max_len and len(example.tgt) <= opt.max_len

EN = SentencePieceField(init_token=Constants.BOS_WORD,
                        eos_token=Constants.EOS_WORD,
                        batch_first=True)

train = datasets.TranslationDataset(
    path=os.path.join(pardir, './data/dualgan/train',),
    exts=('.billion.sp', '.use.sp'), fields=[('src', EN), ('tgt', EN)],
    filter_pred=len_filter)

# 讀取 vocabulary（確保一致）
try:
    logging.info('Load voab from %s' % opt.load_vocab_from)
    EN.load_vocab(opt.load_vocab_from)
except FileNotFoundError:
    EN.build_vocab_from(opt.build_vocab_from)
    EN.save_vocab(opt.load_vocab_from)

logging.info('Vocab len: %d' % len(EN.vocab))

# 檢查Constants是否有誤
assert EN.vocab.stoi[Constants.BOS_WORD] == Constants.BOS
assert EN.vocab.stoi[Constants.EOS_WORD] == Constants.EOS
assert EN.vocab.stoi[Constants.PAD_WORD] == Constants.PAD
assert EN.vocab.stoi[Constants.UNK_WORD] == Constants.UNK

Load voab from /Users/chi/Work/pytorch-seq2seq/./experiment/transformer/lang8-cor2err/vocab.pt
Load voab from /Users/chi/Work/pytorch-seq2seq/./experiment/transformer/lang8-cor2err/vocab.pt
Load voab from /Users/chi/Work/pytorch-seq2seq/./experiment/transformer/lang8-cor2err/vocab.pt
Vocab len: 8003
Vocab len: 8003
Vocab len: 8003


### 初始化model

In [20]:
G_a = load_G(opt.load_G_a_from)
G_b = load_G(opt.load_G_b_from)
D_a = build_D(opt, EN)
D_b = build_D(opt, EN)

optim_G_a = optim.Adam(G_a.get_trainable_parameters(),
                       betas=(0.9, 0.98), eps=1e-09)
optim_G_b = optim.Adam(G_a.get_trainable_parameters(),
                       betas=(0.9, 0.98), eps=1e-09)
optim_D_a = torch.optim.Adam(D_a.parameters(), lr=1e-4)
optim_D_b = torch.optim.Adam(D_b.parameters(), lr=1e-4)

def get_criterion(vocab_size):
    ''' With PAD token zero weight '''
    weight = torch.ones(vocab_size)
    weight[Constants.PAD] = 0
    return nn.CrossEntropyLoss(weight, size_average=False)

crit_G = get_criterion(len(EN.vocab))
crit_D = nn.BCELoss()

if opt.cuda:
    G_a.cuda()
    G_b.cuda()
    D_a.cuda()
    D_b.cuda()
    crit_G.cuda()
    crit_D.cuda()

NameError: name 'load_G' is not defined

### 訓練

In [None]:
trainer_G = trainers.TransformerTrainer()
trainer = trainers.DualGanPGTrainer(
    opt,
    trainer_G=trainer_G,
    trainer_D=trainers.DiscriminatorTrainer())

def eval_G(model):
    _, val_iter = data.BucketIterator.splits(
        (train_lang8, val_lang8), batch_sizes=(opt.batch_size, 128), device=opt.device,
        sort_key=lambda x: len(x.src), repeat=False)
    trainer_G.evaluate(model, val_iter, crit_G, EN)

for epoch in range(10):
    logging.info('[Epoch %d]' % epoch)

    train_iter = data.BucketIterator(
        dataset=train, batch_size=opt.batch_size, device=opt.device,
        sort_key=lambda x: len(x.src), repeat=False)
    # batch = next(iter(train_iter))
    # src_seq = batch.src
    # tgt_seq = batch.tgt

    trainer.train(
        0,
        train_iter,
        G_a=G_a,
        G_b=G_b,
        D_a=D_a,
        D_b=D_b,
        optim_G_a=optim_G_a,
        optim_G_b=optim_G_b,
        optim_D_a=optim_D_a,
        optim_D_b=optim_D_b,
        crit_G=crit_G,
        crit_D=crit_D,
        eval_G=eval_G,
        A_FIELD=EN,
        B_FIELD=EN)